diff --git a/tests/restful_client/base/alias_service.py b/tests/restful_client/README.md similarity index 100% rename from tests/restful_client/base/alias_service.py rename to tests/restful_client/README.md diff --git a/tests/restful_client/api/alias.py b/tests/restful_client/api/alias.py deleted file mode 100644 index b67a08d0a7..0000000000 --- a/tests/restful_client/api/alias.py +++ /dev/null @@ -1,17 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Alias(RestClient): - - def drop_alias(): - pass - - def alter_alias(): - pass - - def create_alias(): - pass - \ No newline at end of file diff --git a/tests/restful_client/api/collection.py b/tests/restful_client/api/collection.py deleted file mode 100644 index a5f542dfac..0000000000 --- a/tests/restful_client/api/collection.py +++ /dev/null @@ -1,62 +0,0 @@ -import json - -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Collection(RestClient): - - @DELETE("collection") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def drop_collection(self, payload): - """Drop a collection""" - - @GET("collection") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def describe_collection(self, payload): - """Describe a collection""" - - @POST("collection") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def create_collection(self, payload): - """Create a collection""" - - @GET("collection/existence") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def has_collection(self, payload): - """Check if a collection exists""" - - @DELETE("collection/load") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def release_collection(self, payload): - """Release a collection""" - - @POST("collection/load") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def load_collection(self, payload): - """Load a collection""" - - @GET("collection/statistics") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def get_collection_statistics(self, payload): - """Get collection statistics""" - - @GET("collections") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def show_collections(self, payload): - """Show collections""" - - -if __name__ == '__main__': - client = Collection("http://localhost:19121/api/v1") - print(client) diff --git a/tests/restful_client/api/credential.py b/tests/restful_client/api/credential.py deleted file mode 100644 index eb4cc3b535..0000000000 --- a/tests/restful_client/api/credential.py +++ /dev/null @@ -1,19 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Credential(RestClient): - - def delete_credential(): - pass - - def update_credential(): - pass - - def create_credential(): - pass - - def list_credentials(): - pass diff --git a/tests/restful_client/api/entity.py b/tests/restful_client/api/entity.py deleted file mode 100644 index fc7a98dd93..0000000000 --- a/tests/restful_client/api/entity.py +++ 
/dev/null @@ -1,63 +0,0 @@ -import json -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Entity(RestClient): - - @POST("distance") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def calc_distance(self, payload): - """ Calculate distance between two points """ - - @DELETE("entities") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def delete(self, payload): - """delete entities""" - - @POST("entities") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def insert(self, payload): - """insert entities""" - - @POST("persist") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def flush(self, payload): - """flush entities""" - - @POST("persist/segment-info") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def get_persistent_segment_info(self, payload): - """get persistent segment info""" - - @POST("persist/state") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def get_flush_state(self, payload): - """get flush state""" - - @POST("query") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def query(self, payload): - """query entities""" - - @POST("query-segment-info") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def get_query_segment_info(self, payload): - """get query segment info""" - - @POST("search") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def search(self, payload): - """search entities""" - \ No newline at end of file diff --git a/tests/restful_client/api/import.py b/tests/restful_client/api/import.py deleted file mode 100644 index 4b0ed2f2cb..0000000000 --- a/tests/restful_client/api/import.py +++ /dev/null @@ -1,18 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - -class Import(RestClient): - - def list_import_tasks(): - pass - - def exec_import(): - pass - - def get_import_state(): - pass - - - diff --git a/tests/restful_client/api/index.py b/tests/restful_client/api/index.py deleted file mode 100644 index c0c4f6a341..0000000000 --- a/tests/restful_client/api/index.py +++ /dev/null @@ -1,38 +0,0 @@ -import json -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Index(RestClient): - - @DELETE("/index") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def drop_index(self, payload): - """Drop an index""" - - @GET("/index") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def describe_index(self, payload): - """Describe an index""" - - @POST("index") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def create_index(self, payload): - """create index""" - - @GET("index/progress") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def get_index_build_progress(self, payload): - """get index build progress""" - - @GET("index/state") - @body("payload", lambda p: json.dumps(p)) - @on(200, lambda r: r.json()) - def 
get_index_state(self, payload): - """get index state""" diff --git a/tests/restful_client/api/metrics.py b/tests/restful_client/api/metrics.py deleted file mode 100644 index c02a9c7f4f..0000000000 --- a/tests/restful_client/api/metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Metrics(RestClient): - - def get_metrics(): - pass - \ No newline at end of file diff --git a/tests/restful_client/api/milvus.py b/tests/restful_client/api/milvus.py new file mode 100644 index 0000000000..ba8cc9a665 --- /dev/null +++ b/tests/restful_client/api/milvus.py @@ -0,0 +1,257 @@ +import json +import requests +import time +import uuid +from utils.util_log import test_log as logger + + +def logger_request_response(response, url, tt, headers, data, str_data, str_response, method): + if len(data) > 2000: + data = data[:1000] + "..." + data[-1000:] + try: + if response.status_code == 200: + if ('code' in response.json() and response.json()["code"] == 200) or ('Code' in response.json() and response.json()["Code"] == 0): + logger.debug( + f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {str_data}, response: {str_response}") + else: + logger.error( + f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}") + else: + logger.error( + f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}") + except Exception as e: + logger.error(e) + logger.error( + f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}") + + +class Requests: + def __init__(self, url=None, api_key=None): + self.url = url + self.api_key = api_key + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', + 'RequestId': str(uuid.uuid1()) + } + + def update_headers(self): + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', + 'RequestId': str(uuid.uuid1()) + } + return headers + + def post(self, url, headers=None, data=None): + headers = headers if headers is not None else self.update_headers() + data = json.dumps(data) + str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data + t0 = time.time() + response = requests.post(url, headers=headers, data=data) + tt = time.time() - t0 + str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text + logger_request_response(response, url, tt, headers, data, str_data, str_response, "post") + return response + + def get(self, url, headers=None, params=None, data=None): + headers = headers if headers is not None else self.update_headers() + data = json.dumps(data) + str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data + t0 = time.time() + if data is None or data == "null": + response = requests.get(url, headers=headers, params=params) + else: + response = requests.get(url, headers=headers, params=params, data=data) + tt = time.time() - t0 + str_response = response.text[:200] + '...' 
+ response.text[-200:] if len(response.text) > 400 else response.text + logger_request_response(response, url, tt, headers, data, str_data, str_response, "get") + return response + + def put(self, url, headers=None, data=None): + headers = headers if headers is not None else self.update_headers() + data = json.dumps(data) + str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data + t0 = time.time() + response = requests.put(url, headers=headers, data=data) + tt = time.time() - t0 + str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text + logger_request_response(response, url, tt, headers, data, str_data, str_response, "put") + return response + + def delete(self, url, headers=None, data=None): + headers = headers if headers is not None else self.update_headers() + data = json.dumps(data) + str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data + t0 = time.time() + response = requests.delete(url, headers=headers, data=data) + tt = time.time() - t0 + str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text + logger_request_response(response, url, tt, headers, data, str_data, str_response, "delete") + return response + + +class VectorClient(Requests): + def __init__(self, url, api_key, protocol="http"): + super().__init__(url, api_key) + self.protocol = protocol + self.url = url + self.api_key = api_key + self.db_name = None + self.headers = self.update_headers() + + def update_headers(self): + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', + 'RequestId': str(uuid.uuid1()) + } + return headers + + def vector_search(self, payload, db_name="default", timeout=10): + time.sleep(1) + url = f'{self.protocol}://{self.url}/vector/search' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if "data" in rsp and len(rsp["data"]) == 0: + t0 = time.time() + while time.time() - t0 < timeout: + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if len(rsp["data"]) > 0: + break + time.sleep(1) + else: + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if "data" in rsp and len(rsp["data"]) == 0: + logger.info(f"after {timeout}s, still no data") + + return response.json() + + def vector_query(self, payload, db_name="default", timeout=10): + time.sleep(1) + url = f'{self.protocol}://{self.url}/vector/query' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if "data" in rsp and len(rsp["data"]) == 0: + t0 = time.time() + while time.time() - t0 < timeout: + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if len(rsp["data"]) > 0: + break + time.sleep(1) + else: + response = self.post(url, headers=self.update_headers(), data=payload) + rsp = response.json() + if "data" in rsp and len(rsp["data"]) == 0: + logger.info(f"after {timeout}s, still no data") + + return response.json() + + def vector_get(self, payload, db_name="default"): + time.sleep(1) + url = f'{self.protocol}://{self.url}/vector/get' + if self.db_name is not None: + payload["dbName"] 
= self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + return response.json() + + def vector_delete(self, payload, db_name="default"): + url = f'{self.protocol}://{self.url}/vector/delete' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + return response.json() + + def vector_insert(self, payload, db_name="default"): + url = f'{self.protocol}://{self.url}/vector/insert' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + return response.json() + + +class CollectionClient(Requests): + + def __init__(self, url, api_key, protocol="http"): + super().__init__(url, api_key) + self.protocol = protocol + self.url = url + self.api_key = api_key + self.db_name = None + self.headers = self.update_headers() + + def update_headers(self): + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}', + 'RequestId': str(uuid.uuid1()) + } + return headers + + def collection_list(self, db_name="default"): + url = f'{self.protocol}://{self.url}/vector/collections' + params = {} + if self.db_name is not None: + params = { + "dbName": self.db_name + } + if db_name != "default": + params = { + "dbName": db_name + } + response = self.get(url, headers=self.update_headers(), params=params) + res = response.json() + return res + + def collection_create(self, payload, db_name="default"): + time.sleep(1) # wait for collection created and in case of rate limit + url = f'{self.protocol}://{self.url}/vector/collections/create' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + return response.json() + + def collection_describe(self, collection_name, db_name="default"): + url = f'{self.protocol}://{self.url}/vector/collections/describe' + params = {"collectionName": collection_name} + if self.db_name is not None: + params = { + "collectionName": collection_name, + "dbName": self.db_name + } + if db_name != "default": + params = { + "collectionName": collection_name, + "dbName": db_name + } + response = self.get(url, headers=self.update_headers(), params=params) + return response.json() + + def collection_drop(self, payload, db_name="default"): + time.sleep(1) # wait for collection drop and in case of rate limit + url = f'{self.protocol}://{self.url}/vector/collections/drop' + if self.db_name is not None: + payload["dbName"] = self.db_name + if db_name != "default": + payload["dbName"] = db_name + response = self.post(url, headers=self.update_headers(), data=payload) + return response.json() diff --git a/tests/restful_client/api/ops.py b/tests/restful_client/api/ops.py deleted file mode 100644 index 4859c334f0..0000000000 --- a/tests/restful_client/api/ops.py +++ /dev/null @@ -1,21 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - -class Ops(RestClient): - - def manual_compaction(): - pass - - def get_compaction_plans(): - pass - - def get_compaction_state(): - pass - - def load_balance(): - 
pass - - def get_replicas(): - pass \ No newline at end of file diff --git a/tests/restful_client/api/partition.py b/tests/restful_client/api/partition.py deleted file mode 100644 index 845b646617..0000000000 --- a/tests/restful_client/api/partition.py +++ /dev/null @@ -1,27 +0,0 @@ -from decorest import GET, POST, DELETE -from decorest import HttpStatus, RestClient -from decorest import accept, body, content, endpoint, form -from decorest import header, multipart, on, query, stream, timeout - - -class Partition(RestClient): - def drop_partition(): - pass - - def create_partition(): - pass - - def has_partition(): - pass - - def get_partition_statistics(): - pass - - def show_partitions(): - pass - - def release_partition(): - pass - - def load_partition(): - pass diff --git a/tests/restful_client/api/pydantic_demo.py b/tests/restful_client/api/pydantic_demo.py deleted file mode 100644 index 43d901f222..0000000000 --- a/tests/restful_client/api/pydantic_demo.py +++ /dev/null @@ -1,43 +0,0 @@ -from datetime import date, datetime -from typing import List, Union, Optional - -from pydantic import BaseModel, UUID4, conlist - -from pydantic_factories import ModelFactory - - -class Person(BaseModel): - def __init__(self, length): - super().__init__() - self.len = length - id: UUID4 - name: str - hobbies: List[str] - age: Union[float, int] - birthday: Union[datetime, date] - - -class Pet(BaseModel): - name: str - age: int - - -class PetFactory(BaseModel): - name: str - pet: Pet - age: Optional[int] = None - - -sample = { - "name": "John", - "pet": { - "name": "Fido", - "age": 3 - } - -} - -result = PetFactory(**sample) - -print(result) - diff --git a/tests/restful_client/base/client_base.py b/tests/restful_client/base/client_base.py deleted file mode 100644 index 2e52f88bc5..0000000000 --- a/tests/restful_client/base/client_base.py +++ /dev/null @@ -1,72 +0,0 @@ -from time import sleep - -from decorest import HttpStatus, RestClient -from models.schema import CollectionSchema -from base.collection_service import CollectionService -from base.index_service import IndexService -from base.entity_service import EntityService - -from utils.util_log import test_log as log -from common import common_func as cf -from common import common_type as ct - -class Base: - """init base class""" - - endpoint = None - collection_service = None - index_service = None - entity_service = None - collection_name = None - collection_object_list = [] - - def setup_class(self): - log.info("setup class") - - def teardown_class(self): - log.info("teardown class") - - def setup_method(self, method): - log.info(("*" * 35) + " setup " + ("*" * 35)) - log.info("[setup_method] Start setup test case %s." 
% method.__name__) - host = cf.param_info.param_host - port = cf.param_info.param_port - self.endpoint = "http://" + host + ":" + str(port) + "/api/v1" - self.collection_service = CollectionService(self.endpoint) - self.index_service = IndexService(self.endpoint) - self.entity_service = EntityService(self.endpoint) - - def teardown_method(self, method): - res = self.collection_service.has_collection(collection_name=self.collection_name) - log.info(f"collection {self.collection_name} exists: {res}") - if res["value"] is True: - res = self.collection_service.drop_collection(self.collection_name) - log.info(f"drop collection {self.collection_name} res: {res}") - res = self.collection_service.show_collections() - all_collections = res["collection_names"] - union_collections = set(all_collections) & set(self.collection_object_list) - for collection in union_collections: - res = self.collection_service.drop_collection(collection) - log.info(f"drop collection {collection} res: {res}") - log.info("[teardown_method] Start teardown test case %s." % method.__name__) - log.info(("*" * 35) + " teardown " + ("*" * 35)) - - -class TestBase(Base): - """init test base class""" - - def init_collection(self, name=None, schema=None): - collection_name = cf.gen_unique_str("test") if name is None else name - self.collection_name = collection_name - self.collection_object_list.append(collection_name) - if schema is None: - schema = cf.gen_default_schema(collection_name=collection_name) - # create collection - res = self.collection_service.create_collection(collection_name=collection_name, schema=schema) - - log.info(f"create collection name: {collection_name} with schema: {schema}") - return collection_name, schema - - - - diff --git a/tests/restful_client/base/collection_service.py b/tests/restful_client/base/collection_service.py deleted file mode 100644 index 292464f451..0000000000 --- a/tests/restful_client/base/collection_service.py +++ /dev/null @@ -1,96 +0,0 @@ -from api.collection import Collection -from utils.util_log import test_log as log -from models import milvus - - -TIMEOUT = 30 - - -class CollectionService: - - def __init__(self, endpoint=None, timeout=None): - if timeout is None: - timeout = TIMEOUT - if endpoint is None: - endpoint = "http://localhost:9091/api/v1" - self._collection = Collection(endpoint=endpoint) - - def create_collection(self, collection_name, consistency_level=1, schema=None, shards_num=2): - payload = { - "collection_name": collection_name, - "consistency_level": consistency_level, - "schema": schema, - "shards_num": shards_num - } - log.info(f"payload: {payload}") - # payload = milvus.CreateCollectionRequest(collection_name=collection_name, - # consistency_level=consistency_level, - # schema=schema, - # shards_num=shards_num) - # payload = payload.dict() - rsp = self._collection.create_collection(payload) - return rsp - - def has_collection(self, collection_name=None, time_stamp=0): - payload = { - "collection_name": collection_name, - "time_stamp": time_stamp - } - # payload = milvus.HasCollectionRequest(collection_name=collection_name, time_stamp=time_stamp) - # payload = payload.dict() - return self._collection.has_collection(payload) - - def drop_collection(self, collection_name): - payload = { - "collection_name": collection_name - } - # payload = milvus.DropCollectionRequest(collection_name=collection_name) - # payload = payload.dict() - return self._collection.drop_collection(payload) - - def describe_collection(self, collection_name, collection_id=None, 
time_stamp=0): - payload = { - "collection_name": collection_name, - "collection_id": collection_id, - "time_stamp": time_stamp - } - # payload = milvus.DescribeCollectionRequest(collection_name=collection_name, - # collectionID=collection_id, - # time_stamp=time_stamp) - # payload = payload.dict() - return self._collection.describe_collection(payload) - - def load_collection(self, collection_name, replica_number=1): - payload = { - "collection_name": collection_name, - "replica_number": replica_number - } - # payload = milvus.LoadCollectionRequest(collection_name=collection_name, replica_number=replica_number) - # payload = payload.dict() - return self._collection.load_collection(payload) - - def release_collection(self, collection_name): - payload = { - "collection_name": collection_name - } - # payload = milvus.ReleaseCollectionRequest(collection_name=collection_name) - # payload = payload.dict() - return self._collection.release_collection(payload) - - def get_collection_statistics(self, collection_name): - payload = { - "collection_name": collection_name - } - # payload = milvus.GetCollectionStatisticsRequest(collection_name=collection_name) - # payload = payload.dict() - return self._collection.get_collection_statistics(payload) - - def show_collections(self, collection_names=None, type=0): - - payload = { - "collection_names": collection_names, - "type": type - } - # payload = milvus.ShowCollectionsRequest(collection_names=collection_names, type=type) - # payload = payload.dict() - return self._collection.show_collections(payload) diff --git a/tests/restful_client/base/credential_service.py b/tests/restful_client/base/credential_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/restful_client/base/entity_service.py b/tests/restful_client/base/entity_service.py deleted file mode 100644 index cf5464ecc6..0000000000 --- a/tests/restful_client/base/entity_service.py +++ /dev/null @@ -1,182 +0,0 @@ -from api.entity import Entity -from common import common_type as ct -from utils.util_log import test_log as log -from models import common, schema, milvus, server - -TIMEOUT = 30 - - -class EntityService: - - def __init__(self, endpoint=None, timeout=None): - if timeout is None: - timeout = TIMEOUT - if endpoint is None: - endpoint = "http://localhost:9091/api/v1" - self._entity = Entity(endpoint=endpoint) - - def calc_distance(self, base=None, op_left=None, op_right=None, params=None): - payload = { - "base": base, - "op_left": op_left, - "op_right": op_right, - "params": params - } - # payload = milvus.CalcDistanceRequest(base=base, op_left=op_left, op_right=op_right, params=params) - # payload = payload.dict() - return self._entity.calc_distance(payload) - - def delete(self, base=None, collection_name=None, db_name=None, expr=None, hash_keys=None, partition_name=None): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name, - "expr": expr, - "hash_keys": hash_keys, - "partition_name": partition_name - } - # payload = server.DeleteRequest(base=base, - # collection_name=collection_name, - # db_name=db_name, - # expr=expr, - # hash_keys=hash_keys, - # partition_name=partition_name) - # payload = payload.dict() - return self._entity.delete(payload) - - def insert(self, base=None, collection_name=None, db_name=None, fields_data=None, hash_keys=None, num_rows=None, - partition_name=None, check_task=True): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name, - "fields_data": fields_data, - 
"hash_keys": hash_keys, - "num_rows": num_rows, - "partition_name": partition_name - } - # payload = milvus.InsertRequest(base=base, - # collection_name=collection_name, - # db_name=db_name, - # fields_data=fields_data, - # hash_keys=hash_keys, - # num_rows=num_rows, - # partition_name=partition_name) - # payload = payload.dict() - rsp = self._entity.insert(payload) - if check_task: - assert rsp["status"] == {} - assert rsp["insert_cnt"] == num_rows - return rsp - - def flush(self, base=None, collection_names=None, db_name=None, check_task=True): - payload = { - "base": base, - "collection_names": collection_names, - "db_name": db_name - } - # payload = server.FlushRequest(base=base, - # collection_names=collection_names, - # db_name=db_name) - # payload = payload.dict() - rsp = self._entity.flush(payload) - if check_task: - assert rsp["status"] == {} - - def get_persistent_segment_info(self, base=None, collection_name=None, db_name=None): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name - } - # payload = server.GetPersistentSegmentInfoRequest(base=base, - # collection_name=collection_name, - # db_name=db_name) - # payload = payload.dict() - return self._entity.get_persistent_segment_info(payload) - - def get_flush_state(self, segment_ids=None): - payload = { - "segment_ids": segment_ids - } - # payload = server.GetFlushStateRequest(segment_ids=segment_ids) - # payload = payload.dict() - return self._entity.get_flush_state(payload) - - def query(self, base=None, collection_name=None, db_name=None, expr=None, - guarantee_timestamp=None, output_fields=None, partition_names=None, travel_timestamp=None, - check_task=True): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name, - "expr": expr, - "guarantee_timestamp": guarantee_timestamp, - "output_fields": output_fields, - "partition_names": partition_names, - "travel_timestamp": travel_timestamp - } - # - # payload = server.QueryRequest(base=base, collection_name=collection_name, db_name=db_name, expr=expr, - # guarantee_timestamp=guarantee_timestamp, output_fields=output_fields, - # partition_names=partition_names, travel_timestamp=travel_timestamp) - # payload = payload.dict() - rsp = self._entity.query(payload) - if check_task: - fields_data = rsp["fields_data"] - for field_data in fields_data: - if field_data["field_name"] in expr: - data = field_data["Field"]["Scalars"]["Data"]["LongData"]["data"] - for d in data: - s = expr.replace(field_data["field_name"], str(d)) - assert eval(s) is True - return rsp - - def get_query_segment_info(self, base=None, collection_name=None, db_name=None): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name - } - # payload = server.GetQuerySegmentInfoRequest(base=base, - # collection_name=collection_name, - # db_name=db_name) - # payload = payload.dict() - return self._entity.get_query_segment_info(payload) - - def search(self, base=None, collection_name=None, vectors=None, db_name=None, dsl=None, - output_fields=None, dsl_type=1, - guarantee_timestamp=None, partition_names=None, placeholder_group=None, - search_params=None, travel_timestamp=None, check_task=True): - payload = { - "base": base, - "collection_name": collection_name, - "output_fields": output_fields, - "vectors": vectors, - "db_name": db_name, - "dsl": dsl, - "dsl_type": dsl_type, - "guarantee_timestamp": guarantee_timestamp, - "partition_names": partition_names, - "placeholder_group": placeholder_group, - "search_params": 
search_params, - "travel_timestamp": travel_timestamp - } - # payload = server.SearchRequest(base=base, collection_name=collection_name, db_name=db_name, dsl=dsl, - # dsl_type=dsl_type, guarantee_timestamp=guarantee_timestamp, - # partition_names=partition_names, placeholder_group=placeholder_group, - # search_params=search_params, travel_timestamp=travel_timestamp) - # payload = payload.dict() - rsp = self._entity.search(payload) - if check_task: - assert rsp["status"] == {} - assert rsp["results"]["num_queries"] == len(vectors) - assert len(rsp["results"]["ids"]["IdField"]["IntId"]["data"]) == sum(rsp["results"]["topks"]) - return rsp - - - - - - - diff --git a/tests/restful_client/base/error_code_message.py b/tests/restful_client/base/error_code_message.py new file mode 100644 index 0000000000..0fffd4fbcb --- /dev/null +++ b/tests/restful_client/base/error_code_message.py @@ -0,0 +1,41 @@ +from enum import Enum + + +class BaseError(Enum): + pass + + +class VectorInsertError(BaseError): + pass + + +class VectorSearchError(BaseError): + pass + + +class VectorGetError(BaseError): + pass + + +class VectorQueryError(BaseError): + pass + + +class VectorDeleteError(BaseError): + pass + + +class CollectionListError(BaseError): + pass + + +class CollectionCreateError(BaseError): + pass + + +class CollectionDropError(BaseError): + pass + + +class CollectionDescribeError(BaseError): + pass diff --git a/tests/restful_client/base/import_service.py b/tests/restful_client/base/import_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/restful_client/base/index_service.py b/tests/restful_client/base/index_service.py deleted file mode 100644 index ac48f35541..0000000000 --- a/tests/restful_client/base/index_service.py +++ /dev/null @@ -1,57 +0,0 @@ -from api.index import Index -from models import common, schema, milvus, server - -TIMEOUT = 30 - - -class IndexService: - - def __init__(self, endpoint=None, timeout=None): - if timeout is None: - timeout = TIMEOUT - if endpoint is None: - endpoint = "http://localhost:9091/api/v1" - self._index = Index(endpoint=endpoint) - - def drop_index(self, base, collection_name, db_name, field_name, index_name): - payload = server.DropIndexRequest(base=base, collection_name=collection_name, - db_name=db_name, field_name=field_name, index_name=index_name) - payload = payload.dict() - return self._index.drop_index(payload) - - def describe_index(self, base, collection_name, db_name, field_name, index_name): - payload = server.DescribeIndexRequest(base=base, collection_name=collection_name, - db_name=db_name, field_name=field_name, index_name=index_name) - payload = payload.dict() - return self._index.describe_index(payload) - - def create_index(self, base=None, collection_name=None, db_name=None, extra_params=None, - field_name=None, index_name=None): - payload = { - "base": base, - "collection_name": collection_name, - "db_name": db_name, - "extra_params": extra_params, - "field_name": field_name, - "index_name": index_name - } - # payload = server.CreateIndexRequest(base=base, collection_name=collection_name, db_name=db_name, - # extra_params=extra_params, field_name=field_name, index_name=index_name) - # payload = payload.dict() - return self._index.create_index(payload) - - def get_index_build_progress(self, base, collection_name, db_name, field_name, index_name): - payload = server.GetIndexBuildProgressRequest(base=base, collection_name=collection_name, - db_name=db_name, field_name=field_name, index_name=index_name) - payload = 
payload.dict() - return self._index.get_index_build_progress(payload) - - def get_index_state(self, base, collection_name, db_name, field_name, index_name): - payload = server.GetIndexStateRequest(base=base, collection_name=collection_name, - db_name=db_name, field_name=field_name, index_name=index_name) - payload = payload.dict() - return self._index.get_index_state(payload) - - - - diff --git a/tests/restful_client/base/metrics_service.py b/tests/restful_client/base/metrics_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/restful_client/base/ops_service.py b/tests/restful_client/base/ops_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/restful_client/base/partition_service.py b/tests/restful_client/base/partition_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/restful_client/base/testbase.py b/tests/restful_client/base/testbase.py new file mode 100644 index 0000000000..758f308848 --- /dev/null +++ b/tests/restful_client/base/testbase.py @@ -0,0 +1,114 @@ +import json +import sys + +import pytest +import time +from pymilvus import connections, db +from utils.util_log import test_log as logger +from api.milvus import VectorClient, CollectionClient +from utils.utils import get_data_by_payload + + +def get_config(): + pass + + +class Base: + name = None + host = None + port = None + url = None + api_key = None + username = None + password = None + invalid_api_key = None + vector_client = None + collection_client = None + + +class TestBase(Base): + + def teardown_method(self): + self.collection_client.api_key = self.api_key + all_collections = self.collection_client.collection_list()['data'] + if self.name in all_collections: + logger.info(f"collection {self.name} exist, drop it") + payload = { + "collectionName": self.name, + } + try: + rsp = self.collection_client.collection_drop(payload) + except Exception as e: + logger.error(e) + + @pytest.fixture(scope="function", autouse=True) + def init_client(self, host, port, username, password): + self.host = host + self.port = port + self.url = f"{host}:{port}/v1" + self.username = username + self.password = password + self.api_key = f"{self.username}:{self.password}" + self.invalid_api_key = "invalid_token" + self.vector_client = VectorClient(self.url, self.api_key) + self.collection_client = CollectionClient(self.url, self.api_key) + + def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100, batch_size=1000): + # create collection + schema_payload = { + "collectionName": collection_name, + "dimension": dim, + "metricType": metric_type, + "description": "test collection", + "primaryField": pk_field, + "vectorField": "vector", + } + rsp = self.collection_client.collection_create(schema_payload) + assert rsp['code'] == 200 + self.wait_collection_load_completed(collection_name) + batch_size = batch_size + batch = nb // batch_size + # in case of nb < batch_size + if batch == 0: + batch = 1 + batch_size = nb + data = [] + for i in range(batch): + nb = batch_size + data = get_data_by_payload(schema_payload, nb) + payload = { + "collectionName": collection_name, + "data": data + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 200 + return schema_payload, data + + def wait_collection_load_completed(self, name): + t0 = time.time() + timeout = 60 + while True and time.time() - t0 < 
timeout: + rsp = self.collection_client.collection_describe(name) + if "data" in rsp and "load" in rsp["data"] and rsp["data"]["load"] == "LoadStateLoaded": + break + else: + time.sleep(5) + + def create_database(self, db_name="default"): + connections.connect(host=self.host, port=self.port) + all_db = db.list_database() + logger.info(f"all database: {all_db}") + if db_name not in all_db: + logger.info(f"create database: {db_name}") + try: + db.create_database(db_name=db_name) + except Exception as e: + logger.error(e) + + def update_database(self, db_name="default"): + self.create_database(db_name=db_name) + self.collection_client.db_name = db_name + self.vector_client.db_name = db_name + diff --git a/tests/restful_client/check/func_check.py b/tests/restful_client/check/func_check.py deleted file mode 100644 index 13a1711e07..0000000000 --- a/tests/restful_client/check/func_check.py +++ /dev/null @@ -1,27 +0,0 @@ - -class CheckTasks: - """ The name of the method used to check the result """ - check_nothing = "check_nothing" - err_res = "error_response" - ccr = "check_connection_result" - check_collection_property = "check_collection_property" - check_partition_property = "check_partition_property" - check_search_results = "check_search_results" - check_query_results = "check_query_results" - check_query_empty = "check_query_empty" # verify that query result is empty - check_query_not_empty = "check_query_not_empty" - check_distance = "check_distance" - check_delete_compact = "check_delete_compact" - check_merge_compact = "check_merge_compact" - check_role_property = "check_role_property" - check_permission_deny = "check_permission_deny" - check_value_equal = "check_value_equal" - - -class ResponseChecker: - - def __init__(self, check_task, check_items): - self.check_task = check_task - self.check_items = check_items - - diff --git a/tests/restful_client/check/param_check.py b/tests/restful_client/check/param_check.py deleted file mode 100644 index d4938f950c..0000000000 --- a/tests/restful_client/check/param_check.py +++ /dev/null @@ -1,21 +0,0 @@ -from utils.util_log import test_log as log - - -def ip_check(ip): - if ip == "localhost": - return True - - if not isinstance(ip, str): - log.error("[IP_CHECK] IP(%s) is not a string." % ip) - return False - - return True - - -def number_check(num): - if str(num).isdigit(): - return True - - else: - log.error("[NUMBER_CHECK] Number(%s) is not a numbers." 
% num) - return False diff --git a/tests/restful_client/common/common_func.py b/tests/restful_client/common/common_func.py deleted file mode 100644 index 5e95fa174b..0000000000 --- a/tests/restful_client/common/common_func.py +++ /dev/null @@ -1,292 +0,0 @@ -import json -import os -import random -import string -import numpy as np -from enum import Enum -from common import common_type as ct -from utils.util_log import test_log as log - - -class ParamInfo: - def __init__(self): - self.param_host = "" - self.param_port = "" - - def prepare_param_info(self, host, http_port): - self.param_host = host - self.param_port = http_port - - -param_info = ParamInfo() - - -class DataType(Enum): - Bool: 1 - Int8: 2 - Int16: 3 - Int32: 4 - Int64: 5 - Float: 10 - Double: 11 - String: 20 - VarChar: 21 - BinaryVector: 100 - FloatVector: 101 - - -def gen_unique_str(str_value=None): - prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) - return "test_" + prefix if str_value is None else str_value + "_" + prefix - - -def gen_field(name=ct.default_bool_field_name, description=ct.default_desc, type_params=None, index_params=None, - data_type="Int64", is_primary_key=False, auto_id=False, dim=128, max_length=256): - data_type_map = { - "Bool": 1, - "Int8": 2, - "Int16": 3, - "Int32": 4, - "Int64": 5, - "Float": 10, - "Double": 11, - "String": 20, - "VarChar": 21, - "BinaryVector": 100, - "FloatVector": 101, - } - if data_type == "Int64": - is_primary_key = True - auto_id = True - if type_params is None: - type_params = [] - if index_params is None: - index_params = [] - if data_type in ["FloatVector", "BinaryVector"]: - type_params = [{"key": "dim", "value": str(dim)}] - if data_type in ["String", "VarChar"]: - type_params = [{"key": "max_length", "value": str(dim)}] - return { - "name": name, - "description": description, - "data_type": data_type_map.get(data_type, 0), - "type_params": type_params, - "index_params": index_params, - "is_primary_key": is_primary_key, - "auto_id": auto_id, - } - - -def gen_schema(name, fields, description=ct.default_desc, auto_id=False): - return { - "name": name, - "description": description, - "auto_id": auto_id, - "fields": fields, - } - - -def gen_default_schema(data_types=None, dim=ct.default_dim, collection_name=None): - if data_types is None: - data_types = ["Int64", "Float", "VarChar", "FloatVector"] - fields = [] - for data_type in data_types: - if data_type in ["FloatVector", "BinaryVector"]: - fields.append(gen_field(name=data_type, data_type=data_type, type_params=[{"key": "dim", "value": dim}])) - else: - fields.append(gen_field(name=data_type, data_type=data_type)) - return { - "autoID": True, - "fields": fields, - "description": ct.default_desc, - "name": collection_name, - } - - -def gen_fields_data(schema=None, nb=ct.default_nb,): - if schema is None: - schema = gen_default_schema() - fields = schema["fields"] - fields_data = [] - for field in fields: - if field["data_type"] == 1: - fields_data.append([random.choice([True, False]) for i in range(nb)]) - elif field["data_type"] == 2: - fields_data.append([i for i in range(nb)]) - elif field["data_type"] == 3: - fields_data.append([i for i in range(nb)]) - elif field["data_type"] == 4: - fields_data.append([i for i in range(nb)]) - elif field["data_type"] == 5: - fields_data.append([i for i in range(nb)]) - elif field["data_type"] == 10: - fields_data.append([np.float64(i) for i in range(nb)]) # json not support float32 - elif field["data_type"] == 11: - 
fields_data.append([np.float64(i) for i in range(nb)]) - elif field["data_type"] == 20: - fields_data.append([gen_unique_str((str(i))) for i in range(nb)]) - elif field["data_type"] == 21: - fields_data.append([gen_unique_str(str(i)) for i in range(nb)]) - elif field["data_type"] == 100: - dim = ct.default_dim - for k, v in field["type_params"]: - if k == "dim": - dim = int(v) - break - fields_data.append(gen_binary_vectors(nb, dim)) - elif field["data_type"] == 101: - dim = ct.default_dim - for k, v in field["type_params"]: - if k == "dim": - dim = int(v) - break - fields_data.append(gen_float_vectors(nb, dim)) - else: - log.error("Unknown data type.") - fields_data_body = [] - for i, field in enumerate(fields): - fields_data_body.append({ - "field_name": field["name"], - "type": field["data_type"], - "field": fields_data[i], - }) - return fields_data_body - - -def get_vector_field(schema): - for field in schema["fields"]: - if field["data_type"] in [100, 101]: - return field["name"] - return None - - -def get_varchar_field(schema): - for field in schema["fields"]: - if field["data_type"] == 21: - return field["name"] - return None - - -def gen_vectors(nq=None, schema=None): - if nq is None: - nq = ct.default_nq - dim = ct.default_dim - data_type = 101 - for field in schema["fields"]: - if field["data_type"] in [100, 101]: - dim = ct.default_dim - data_type = field["data_type"] - for k, v in field["type_params"]: - if k == "dim": - dim = int(v) - break - if data_type == 100: - return gen_binary_vectors(nq, dim) - if data_type == 101: - return gen_float_vectors(nq, dim) - - -def gen_float_vectors(nb, dim): - return [[np.float64(random.uniform(-1.0, 1.0)) for _ in range(dim)] for _ in range(nb)] # json not support float32 - - -def gen_binary_vectors(nb, dim): - raw_vectors = [] - binary_vectors = [] - for _ in range(nb): - raw_vector = [random.randint(0, 1) for _ in range(dim)] - raw_vectors.append(raw_vector) - # packs a binary-valued array into bits in a unit8 array, and bytes array_of_ints - binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) - return binary_vectors - - -def gen_index_params(index_type=None): - if index_type is None: - index_params = ct.default_index_params - else: - index_params = ct.all_index_params_map[index_type] - extra_params = [] - for k, v in index_params.items(): - item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)} - extra_params.append(item) - return extra_params - -def gen_search_param_by_index_type(index_type, metric_type="L2"): - search_params = [] - if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]: - for nprobe in [10]: - ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}} - search_params.append(ivf_search_params) - elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]: - for nprobe in [10]: - bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}} - search_params.append(bin_search_params) - elif index_type in ["HNSW"]: - for ef in [64]: - hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}} - search_params.append(hnsw_search_param) - elif index_type == "ANNOY": - for search_k in [1000]: - annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}} - search_params.append(annoy_search_param) - else: - log.info("Invalid index_type.") - raise Exception("Invalid index_type.") - return search_params - - -def gen_search_params(index_type=None, anns_field=ct.default_float_vec_field_name, - topk=ct.default_top_k): - if 
index_type is None: - search_params = gen_search_param_by_index_type(ct.default_index_type)[0] - else: - search_params = gen_search_param_by_index_type(index_type)[0] - extra_params = [] - for k, v in search_params.items(): - item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)} - extra_params.append(item) - extra_params.append({"key": "anns_field", "value": anns_field}) - extra_params.append({"key": "topk", "value": str(topk)}) - return extra_params - - - - -def gen_search_vectors(dim, nb, is_binary=False): - if is_binary: - return gen_binary_vectors(nb, dim) - return gen_float_vectors(nb, dim) - - -def modify_file(file_path_list, is_modify=False, input_content=""): - """ - file_path_list : file list -> list[] - is_modify : does the file need to be reset - input_content :the content that need to insert to the file - """ - if not isinstance(file_path_list, list): - log.error("[modify_file] file is not a list.") - - for file_path in file_path_list: - folder_path, file_name = os.path.split(file_path) - if not os.path.isdir(folder_path): - log.debug("[modify_file] folder(%s) is not exist." % folder_path) - os.makedirs(folder_path) - - if not os.path.isfile(file_path): - log.error("[modify_file] file(%s) is not exist." % file_path) - else: - if is_modify is True: - log.debug("[modify_file] start modifying file(%s)..." % file_path) - with open(file_path, "r+") as f: - f.seek(0) - f.truncate() - f.write(input_content) - f.close() - log.info("[modify_file] file(%s) modification is complete." % file_path_list) - - -if __name__ == '__main__': - a = gen_binary_vectors(10, 128) - print(a) diff --git a/tests/restful_client/common/common_type.py b/tests/restful_client/common/common_type.py deleted file mode 100644 index 498f8bfc61..0000000000 --- a/tests/restful_client/common/common_type.py +++ /dev/null @@ -1,80 +0,0 @@ -""" Initialized parameters """ -port = 19530 -epsilon = 0.000001 -namespace = "milvus" -default_flush_interval = 1 -big_flush_interval = 1000 -default_drop_interval = 3 -default_dim = 128 -default_nb = 3000 -default_nb_medium = 5000 -default_top_k = 10 -default_nq = 2 -default_limit = 10 -default_search_params = {"metric_type": "L2", "params": {"nprobe": 10}} -default_search_ip_params = {"metric_type": "IP", "params": {"nprobe": 10}} -default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}} -default_index_type = "HNSW" -default_index_params = {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"} -default_varchar_index = {} -default_binary_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} -default_diskann_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} -default_diskann_search_params = {"metric_type": "L2", "params": {"search_list": 30}} -max_top_k = 16384 -max_partition_num = 4096 # 256 -default_segment_row_limit = 1000 -default_server_segment_row_limit = 1024 * 512 -default_alias = "default" -default_user = "root" -default_password = "Milvus" -default_bool_field_name = "Bool" -default_int8_field_name = "Int8" -default_int16_field_name = "Int16" -default_int32_field_name = "Int32" -default_int64_field_name = "Int64" -default_float_field_name = "Float" -default_double_field_name = "Double" -default_string_field_name = "Varchar" -default_float_vec_field_name = "FloatVector" -another_float_vec_field_name = "FloatVector1" -default_binary_vec_field_name = "BinaryVector" -default_partition_name = "_default" -default_tag = "1970_01_01" -row_count = 
"row_count" -default_length = 65535 -default_desc = "" -default_collection_desc = "default collection" -default_index_name = "default_index_name" -default_binary_desc = "default binary collection" -collection_desc = "collection" -int_field_desc = "int64 type field" -float_field_desc = "float type field" -float_vec_field_desc = "float vector type field" -binary_vec_field_desc = "binary vector type field" -max_dim = 32768 -min_dim = 1 -gracefulTime = 1 -default_nlist = 128 -compact_segment_num_threshold = 4 -compact_delta_ratio_reciprocal = 5 # compact_delta_binlog_ratio is 0.2 -compact_retention_duration = 40 # compaction travel time retention range 20s -max_compaction_interval = 60 # the max time interval (s) from the last compaction -max_field_num = 256 # Maximum number of fields in a collection - -default_dsl = f"{default_int64_field_name} in [2,4,6,8]" -default_expr = f"{default_int64_field_name} in [2,4,6,8]" - -metric_types = [] -all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT"] -all_index_params_map = {"FLAT": {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}, - "IVF_FLAT": {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}, - "IVF_SQ8": {"index_type": "IVF_SQ8", "params": {"nlist": 128}, "metric_type": "L2"}, - "IVF_PQ": {"index_type": "IVF_PQ", "params": {"nlist": 128, "m": 16, "nbits": 8}, - "metric_type": "L2"}, - "HNSW": {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"}, - "ANNOY": {"index_type": "ANNOY", "params": {"n_trees": 50}, "metric_type": "L2"}, - "DISKANN": {"index_type": "DISKANN", "params": {}, "metric_type": "L2"}, - "BIN_FLAT": {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"}, - "BIN_IVF_FLAT": {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, - "metric_type": "JACCARD"} - } diff --git a/tests/restful_client/conftest.py b/tests/restful_client/conftest.py index 08f2511e92..0393280c93 100644 --- a/tests/restful_client/conftest.py +++ b/tests/restful_client/conftest.py @@ -1,15 +1,12 @@ import pytest -import common.common_func as cf -from check.param_check import ip_check, number_check -from config.log_config import log_config -from utils.util_log import test_log as log -from common.common_func import param_info +import yaml def pytest_addoption(parser): - parser.addoption("--host", action="store", default="127.0.0.1", help="Milvus host") - parser.addoption("--port", action="store", default="9091", help="Milvus http port") - parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing") + parser.addoption("--host", action="store", default="127.0.0.1", help="host") + parser.addoption("--port", action="store", default="19530", help="port") + parser.addoption("--username", action="store", default="root", help="email") + parser.addoption("--password", action="store", default="Milvus", help="password") @pytest.fixture @@ -23,27 +20,11 @@ def port(request): @pytest.fixture -def clean_log(request): - return request.config.getoption("--clean_log") +def username(request): + return request.config.getoption("--username") -@pytest.fixture(scope="session", autouse=True) -def initialize_env(request): - """ clean log before testing """ - host = request.config.getoption("--host") - port = request.config.getoption("--port") - clean_log = request.config.getoption("--clean_log") +@pytest.fixture +def password(request): + return request.config.getoption("--password") - - 
""" params check """ - assert ip_check(host) and number_check(port) - - """ modify log files """ - file_path_list = [log_config.log_debug, log_config.log_info, log_config.log_err] - if log_config.log_worker != "": - file_path_list.append(log_config.log_worker) - cf.modify_file(file_path_list=file_path_list, is_modify=clean_log) - - log.info("#" * 80) - log.info("[initialize_milvus] Log cleaned up, start testing...") - param_info.prepare_param_info(host, port) diff --git a/tests/restful_client/models/__init__.py b/tests/restful_client/models/__init__.py deleted file mode 100644 index 8d4899ce64..0000000000 --- a/tests/restful_client/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# generated by datamodel-codegen: -# filename: openapi.json -# timestamp: 2022-12-08T02:46:08+00:00 diff --git a/tests/restful_client/models/common.py b/tests/restful_client/models/common.py deleted file mode 100644 index a4c33e484f..0000000000 --- a/tests/restful_client/models/common.py +++ /dev/null @@ -1,28 +0,0 @@ -# generated by datamodel-codegen: -# filename: openapi.json -# timestamp: 2022-12-08T02:46:08+00:00 - -from __future__ import annotations - -from typing import List, Optional - -from pydantic import BaseModel, Field - - -class KeyDataPair(BaseModel): - data: Optional[List[int]] = None - key: Optional[str] = None - - -class KeyValuePair(BaseModel): - key: Optional[str] = Field(None, example='dim') - value: Optional[str] = Field(None, example='128') - - -class MsgBase(BaseModel): - msg_type: Optional[int] = Field(None, description='Not useful for now') - - -class Status(BaseModel): - error_code: Optional[int] = None - reason: Optional[str] = None diff --git a/tests/restful_client/models/milvus.py b/tests/restful_client/models/milvus.py deleted file mode 100644 index 25613b4258..0000000000 --- a/tests/restful_client/models/milvus.py +++ /dev/null @@ -1,138 +0,0 @@ -# generated by datamodel-codegen: -# filename: openapi.json -# timestamp: 2022-12-08T02:46:08+00:00 - -from __future__ import annotations - -from typing import List, Optional - -from pydantic import BaseModel, Field - -from models import common, schema - - -class DescribeCollectionRequest(BaseModel): - collection_name: Optional[str] = None - collectionID: Optional[int] = Field( - None, description='The collection ID you want to describe' - ) - time_stamp: Optional[int] = Field( - None, - description='If time_stamp is not zero, will describe collection success when time_stamp >= created collection timestamp, otherwise will throw error.', - ) - - -class DropCollectionRequest(BaseModel): - collection_name: Optional[str] = Field( - None, description='The unique collection name in milvus.(Required)' - ) - - -class FieldData(BaseModel): - field: Optional[List] = None - field_id: Optional[int] = None - field_name: Optional[str] = None - type: Optional[int] = Field( - None, - description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",', - ) - - -class GetCollectionStatisticsRequest(BaseModel): - collection_name: Optional[str] = Field( - None, description='The collection name you want get statistics' - ) - - -class HasCollectionRequest(BaseModel): - collection_name: Optional[str] = Field( - None, description='The unique collection name in milvus.(Required)' - ) - time_stamp: Optional[int] = Field( - None, - description='If time_stamp is not zero, will return true when time_stamp >= created collection timestamp, otherwise 
will return false.', - ) - - -class InsertRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = None - db_name: Optional[str] = None - fields_data: Optional[List[FieldData]] = None - hash_keys: Optional[List[int]] = None - num_rows: Optional[int] = None - partition_name: Optional[str] = None - - -class LoadCollectionRequest(BaseModel): - collection_name: Optional[str] = Field( - None, description='The collection name you want to load' - ) - replica_number: Optional[int] = Field( - None, description='The replica number to load, default by 1' - ) - - -class ReleaseCollectionRequest(BaseModel): - collection_name: Optional[str] = Field( - None, description='The collection name you want to release' - ) - - -class ShowCollectionsRequest(BaseModel): - collection_names: Optional[List[str]] = Field( - None, - description="When type is InMemory, will return these collection's inMemory_percentages.(Optional)", - ) - type: Optional[int] = Field( - None, - description='Decide return Loaded collections or All collections(Optional)', - ) - - -class VectorIDs(BaseModel): - collection_name: Optional[str] = None - field_name: Optional[str] = None - id_array: Optional[List[int]] = None - partition_names: Optional[List[str]] = None - - -class VectorsArray(BaseModel): - binary_vectors: Optional[List[int]] = Field( - None, - description='Vectors is an array of binary vector divided by given dim. Disabled when IDs is set', - ) - dim: Optional[int] = Field( - None, description='Dim of vectors or binary_vectors, not needed when use ids' - ) - ids: Optional[VectorIDs] = None - vectors: Optional[List[float]] = Field( - None, - description='Vectors is an array of vector divided by given dim. Disabled when ids or binary_vectors is set', - ) - - -class CalcDistanceRequest(BaseModel): - base: Optional[common.MsgBase] = None - op_left: Optional[VectorsArray] = None - op_right: Optional[VectorsArray] = None - params: Optional[List[common.KeyValuePair]] = None - - -class CreateCollectionRequest(BaseModel): - collection_name: str = Field( - ..., - description='The unique collection name in milvus.(Required)', - example='book', - ) - consistency_level: int = Field( - ..., - description='The consistency level that the collection used, modification is not supported now.\n"Strong": 0,\n"Session": 1,\n"Bounded": 2,\n"Eventually": 3,\n"Customized": 4,', - example=1, - ) - schema_: schema.CollectionSchema = Field(..., alias='schema') - shards_num: Optional[int] = Field( - None, - description='Once set, no modification is allowed (Optional)\nhttps://github.com/milvus-io/milvus/issues/6690', - example=1, - ) diff --git a/tests/restful_client/models/schema.py b/tests/restful_client/models/schema.py deleted file mode 100644 index c66c8ac642..0000000000 --- a/tests/restful_client/models/schema.py +++ /dev/null @@ -1,72 +0,0 @@ -# generated by datamodel-codegen: -# filename: openapi.json -# timestamp: 2022-12-08T02:46:08+00:00 - -from __future__ import annotations - -from typing import Any, List, Optional - -from pydantic import BaseModel, Field - -from models import common - - -class FieldData(BaseModel): - field: Optional[Any] = Field( - None, - description='Types that are assignable to Field:\n\t*FieldData_Scalars\n\t*FieldData_Vectors', - ) - field_id: Optional[int] = None - field_name: Optional[str] = None - type: Optional[int] = Field( - None, - description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: 
"VarChar",\n100: "BinaryVector",\n101: "FloatVector",', - ) - - -class FieldSchema(BaseModel): - autoID: Optional[bool] = None - data_type: int = Field( - ..., - description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",', - example=101, - ) - description: Optional[str] = Field( - None, example='embedded vector of book introduction' - ) - fieldID: Optional[int] = None - index_params: Optional[List[common.KeyValuePair]] = None - is_primary_key: Optional[bool] = Field(None, example=False) - name: str = Field(..., example='book_intro') - type_params: Optional[List[common.KeyValuePair]] = None - - -class IDs(BaseModel): - idField: Optional[Any] = Field( - None, - description='Types that are assignable to IdField:\n\t*IDs_IntId\n\t*IDs_StrId', - ) - - -class LongArray(BaseModel): - data: Optional[List[int]] = None - - -class SearchResultData(BaseModel): - fields_data: Optional[List[FieldData]] = None - ids: Optional[IDs] = None - num_queries: Optional[int] = None - scores: Optional[List[float]] = None - top_k: Optional[int] = None - topks: Optional[List[int]] = None - - -class CollectionSchema(BaseModel): - autoID: Optional[bool] = Field( - None, - description='deprecated later, keep compatible with c++ part now', - example=False, - ) - description: Optional[str] = Field(None, example='Test book search') - fields: Optional[List[FieldSchema]] = None - name: str = Field(..., example='book') diff --git a/tests/restful_client/models/server.py b/tests/restful_client/models/server.py deleted file mode 100644 index 57f7d06fe2..0000000000 --- a/tests/restful_client/models/server.py +++ /dev/null @@ -1,587 +0,0 @@ -# generated by datamodel-codegen: -# filename: openapi.json -# timestamp: 2022-12-08T02:46:08+00:00 - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field - -from models import common, schema - - -class AlterAliasRequest(BaseModel): - alias: Optional[str] = None - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = None - db_name: Optional[str] = None - - -class BoolResponse(BaseModel): - status: Optional[common.Status] = None - value: Optional[bool] = None - - -class CalcDistanceResults(BaseModel): - array: Optional[Any] = Field( - None, - description='num(op_left)*num(op_right) distance values, "HAMMIN" return integer distance\n\nTypes that are assignable to Array:\n\t*CalcDistanceResults_IntDist\n\t*CalcDistanceResults_FloatDist', - ) - status: Optional[common.Status] = None - - -class CompactionMergeInfo(BaseModel): - sources: Optional[List[int]] = None - target: Optional[int] = None - - -class CreateAliasRequest(BaseModel): - alias: Optional[str] = None - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = None - db_name: Optional[str] = None - - -class CreateCredentialRequest(BaseModel): - base: Optional[common.MsgBase] = None - created_utc_timestamps: Optional[int] = Field(None, description='create time') - modified_utc_timestamps: Optional[int] = Field(None, description='modify time') - password: Optional[str] = Field(None, description='ciphertext password') - username: Optional[str] = Field(None, description='username') - - -class CreateIndexRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The particular collection name you want to create index.' 
- ) - db_name: Optional[str] = Field(None, description='Not useful for now') - extra_params: Optional[List[common.KeyValuePair]] = Field( - None, - description='Support keys: index_type,metric_type, params. Different index_type may has different params.', - ) - field_name: Optional[str] = Field( - None, description='The vector field name in this particular collection' - ) - index_name: Optional[str] = Field( - None, - description="Version before 2.0.2 doesn't contain index_name, we use default index name.", - ) - - -class CreatePartitionRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_name: Optional[str] = Field( - None, description='The partition name you want to create.' - ) - - -class DeleteCredentialRequest(BaseModel): - base: Optional[common.MsgBase] = None - username: Optional[str] = Field(None, description='Not useful for now') - - -class DeleteRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = None - db_name: Optional[str] = None - expr: Optional[str] = None - hash_keys: Optional[List[int]] = None - partition_name: Optional[str] = None - - -class DescribeIndexRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The particular collection name in Milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - field_name: Optional[str] = Field( - None, description='The vector field name in this particular collection' - ) - index_name: Optional[str] = Field( - None, description='No need to set up for now @2021.06.30' - ) - - -class DropAliasRequest(BaseModel): - alias: Optional[str] = None - base: Optional[common.MsgBase] = None - db_name: Optional[str] = None - - -class DropIndexRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field(None, description='must') - db_name: Optional[str] = None - field_name: Optional[str] = None - index_name: Optional[str] = Field( - None, description='No need to set up for now @2021.06.30' - ) - - -class DropPartitionRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_name: Optional[str] = Field( - None, description='The partition name you want to drop' - ) - - -class FlushRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_names: Optional[List[str]] = None - db_name: Optional[str] = None - - -class FlushResponse(BaseModel): - coll_segIDs: Optional[Dict[str, schema.LongArray]] = None - db_name: Optional[str] = None - status: Optional[common.Status] = None - - -class GetCollectionStatisticsResponse(BaseModel): - stats: Optional[List[common.KeyValuePair]] = Field( - None, description='Collection statistics data' - ) - status: Optional[common.Status] = None - - -class GetCompactionPlansRequest(BaseModel): - compactionID: Optional[int] = None - - -class GetCompactionPlansResponse(BaseModel): - mergeInfos: Optional[List[CompactionMergeInfo]] = None - state: Optional[int] = None - status: Optional[common.Status] = None - - -class GetCompactionStateRequest(BaseModel): - compactionID: Optional[int] = None - - -class GetCompactionStateResponse(BaseModel): - completedPlanNo: 
Optional[int] = None - executingPlanNo: Optional[int] = None - state: Optional[int] = None - status: Optional[common.Status] = None - timeoutPlanNo: Optional[int] = None - - -class GetFlushStateRequest(BaseModel): - segmentIDs: Optional[List[int]] = Field(None, alias='segment_ids') - - -class GetFlushStateResponse(BaseModel): - flushed: Optional[bool] = None - status: Optional[common.Status] = None - - -class GetImportStateRequest(BaseModel): - task: Optional[int] = Field(None, description='id of an import task') - - -class GetImportStateResponse(BaseModel): - id: Optional[int] = Field(None, description='id of an import task') - id_list: Optional[List[int]] = Field( - None, description='auto generated ids if the primary key is autoid' - ) - infos: Optional[List[common.KeyValuePair]] = Field( - None, - description='more informations about the task, progress percent, file path, failed reason, etc.', - ) - row_count: Optional[int] = Field( - None, - description='if the task is finished, this value is how many rows are imported. if the task is not finished, this value is how many rows are parsed. return 0 if failed.', - ) - state: Optional[int] = Field( - None, description='is this import task finished or not' - ) - status: Optional[common.Status] = None - - -class GetIndexBuildProgressRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - field_name: Optional[str] = Field( - None, description='The vector field name in this collection' - ) - index_name: Optional[str] = Field(None, description='Not useful for now') - - -class GetIndexBuildProgressResponse(BaseModel): - indexed_rows: Optional[int] = None - status: Optional[common.Status] = None - total_rows: Optional[int] = None - - -class GetIndexStateRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field(None, description='must') - db_name: Optional[str] = None - field_name: Optional[str] = None - index_name: Optional[str] = Field( - None, description='No need to set up for now @2021.06.30' - ) - - -class GetIndexStateResponse(BaseModel): - fail_reason: Optional[str] = None - state: Optional[int] = None - status: Optional[common.Status] = None - - -class GetMetricsRequest(BaseModel): - base: Optional[common.MsgBase] = None - request: Optional[str] = Field(None, description='request is of jsonic format') - - -class GetMetricsResponse(BaseModel): - component_name: Optional[str] = Field( - None, description='metrics from which component' - ) - response: Optional[str] = Field(None, description='response is of jsonic format') - status: Optional[common.Status] = None - - -class GetPartitionStatisticsRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_name: Optional[str] = Field( - None, description='The partition name you want to collect statistics' - ) - - -class GetPartitionStatisticsResponse(BaseModel): - stats: Optional[List[common.KeyValuePair]] = None - status: Optional[common.Status] = None - - -class GetPersistentSegmentInfoRequest(BaseModel): - base: Optional[common.MsgBase] = None - collectionName: Optional[str] = Field(None, alias="collection_name", description='must') - dbName: Optional[str] = Field(None, alias="db_name") - - 
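The deleted request models above lean on pydantic field aliases to keep Python-style attribute names while emitting the snake_case wire names Milvus expects. A minimal sketch of that round trip, assuming pydantic v1 behavior (matching the removed `pydantic~=1.10.2` pin); the model is re-declared here only for illustration:

    from typing import Optional
    from pydantic import BaseModel, Field

    class GetPersistentSegmentInfoRequest(BaseModel):
        collectionName: Optional[str] = Field(None, alias="collection_name")
        dbName: Optional[str] = Field(None, alias="db_name")

    # pydantic v1 populates by alias by default, so callers pass the wire names:
    req = GetPersistentSegmentInfoRequest(collection_name="book")
    print(req.json(by_alias=True, exclude_none=True))  # {"collection_name": "book"}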
-class GetQuerySegmentInfoRequest(BaseModel): - base: Optional[common.MsgBase] = None - collectionName: Optional[str] = Field(None, alias="collection_name", description='must') - dbName: Optional[str] = Field(None, alias="db_name") - - -class GetReplicasRequest(BaseModel): - base: Optional[common.MsgBase] = None - collectionID: Optional[int] = Field(None, alias="collection_id") - with_shard_nodes: Optional[bool] = None - - -class HasPartitionRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_name: Optional[str] = Field( - None, description='The partition name you want to check' - ) - - -class ImportRequest(BaseModel): - channel_names: Optional[List[str]] = Field( - None, description='channel names for the collection' - ) - collection_name: Optional[str] = Field(None, description='target collection') - files: Optional[List[str]] = Field(None, description='file paths to be imported') - options: Optional[List[common.KeyValuePair]] = Field( - None, description='import options, bucket, etc.' - ) - partition_name: Optional[str] = Field(None, description='target partition') - row_based: Optional[bool] = Field( - None, description='the file is row-based or column-based' - ) - - -class ImportResponse(BaseModel): - status: Optional[common.Status] = None - tasks: Optional[List[int]] = Field(None, description='id array of import tasks') - - -class IndexDescription(BaseModel): - field_name: Optional[str] = Field(None, description='The vector field name') - index_name: Optional[str] = Field(None, description='Index name') - indexID: Optional[int] = Field(None, description='Index id') - params: Optional[List[common.KeyValuePair]] = Field( - None, description='Will return index_type, metric_type, params(like nlist).' 
- ) - - -class ListCredUsersRequest(BaseModel): - base: Optional[common.MsgBase] = None - - -class ListCredUsersResponse(BaseModel): - status: Optional[common.Status] = None - usernames: Optional[List[str]] = Field(None, description='username array') - - -class ListImportTasksRequest(BaseModel): - pass - - -class ListImportTasksResponse(BaseModel): - status: Optional[common.Status] = None - tasks: Optional[List[GetImportStateResponse]] = Field( - None, description='list of all import tasks' - ) - - -class LoadBalanceRequest(BaseModel): - base: Optional[common.MsgBase] = None - collectionName: Optional[str] = None - dst_nodeIDs: Optional[List[int]] = None - sealed_segmentIDs: Optional[List[int]] = None - src_nodeID: Optional[int] = None - - -class LoadPartitionsRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_names: Optional[List[str]] = Field( - None, description='The partition names you want to load' - ) - replica_number: Optional[int] = Field( - None, description='The replicas number you would load, 1 by default' - ) - - -class ManualCompactionRequest(BaseModel): - collectionID: Optional[int] = None - timetravel: Optional[int] = None - - -class ManualCompactionResponse(BaseModel): - compactionID: Optional[int] = None - status: Optional[common.Status] = None - - -class PersistentSegmentInfo(BaseModel): - collectionID: Optional[int] = None - num_rows: Optional[int] = None - partitionID: Optional[int] = None - segmentID: Optional[int] = None - state: Optional[int] = None - - -class QueryRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = None - db_name: Optional[str] = None - expr: Optional[str] = None - guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp') - output_fields: Optional[List[str]] = None - partition_names: Optional[List[str]] = None - travel_timestamp: Optional[int] = None - - -class QueryResults(BaseModel): - collection_name: Optional[str] = None - fields_data: Optional[List[schema.FieldData]] = None - status: Optional[common.Status] = None - - -class QuerySegmentInfo(BaseModel): - collectionID: Optional[int] = None - index_name: Optional[str] = None - indexID: Optional[int] = None - mem_size: Optional[int] = None - nodeID: Optional[int] = None - num_rows: Optional[int] = None - partitionID: Optional[int] = None - segmentID: Optional[int] = None - state: Optional[int] = None - - -class ReleasePartitionsRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, description='The collection name in milvus' - ) - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_names: Optional[List[str]] = Field( - None, description='The partition names you want to release' - ) - - -class SearchRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field(None, description='must') - db_name: Optional[str] = None - dsl: Optional[str] = Field(None, description='must') - dsl_type: Optional[int] = Field(None, description='must') - guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp') - output_fields: Optional[List[str]] = None - partition_names: Optional[List[str]] = Field(None, description='must') - placeholder_group: Optional[List[int]] = Field( - None, description='serialized 
`PlaceholderGroup`' - ) - search_params: Optional[List[common.KeyValuePair]] = Field(None, description='must') - travel_timestamp: Optional[int] = None - - -class SearchResults(BaseModel): - collection_name: Optional[str] = None - results: Optional[schema.SearchResultData] = None - status: Optional[common.Status] = None - - -class ShardReplica(BaseModel): - dm_channel_name: Optional[str] = None - leader_addr: Optional[str] = Field(None, description='IP:port') - leaderID: Optional[int] = None - node_ids: Optional[List[int]] = Field( - None, - description='optional, DO NOT save it in meta, set it only for GetReplicas()\nif with_shard_nodes is true', - ) - - -class ShowCollectionsResponse(BaseModel): - collection_ids: Optional[List[int]] = Field(None, description='Collection Id array') - collection_names: Optional[List[str]] = Field( - None, description='Collection name array' - ) - created_timestamps: Optional[List[int]] = Field( - None, description='Hybrid timestamps in milvus' - ) - created_utc_timestamps: Optional[List[int]] = Field( - None, description='The utc timestamp calculated by created_timestamp' - ) - inMemory_percentages: Optional[List[int]] = Field( - None, description='Load percentage on querynode when type is InMemory' - ) - status: Optional[common.Status] = None - - -class ShowPartitionsRequest(BaseModel): - base: Optional[common.MsgBase] = None - collection_name: Optional[str] = Field( - None, - description='The collection name you want to describe, you can pass collection_name or collectionID', - ) - collectionID: Optional[int] = Field(None, description='The collection id in milvus') - db_name: Optional[str] = Field(None, description='Not useful for now') - partition_names: Optional[List[str]] = Field( - None, - description="When type is InMemory, will return these patitions' inMemory_percentages.(Optional)", - ) - type: Optional[int] = Field( - None, description='Decide return Loaded partitions or All partitions(Optional)' - ) - - -class ShowPartitionsResponse(BaseModel): - created_timestamps: Optional[List[int]] = Field( - None, description='All hybrid timestamps' - ) - created_utc_timestamps: Optional[List[int]] = Field( - None, description='All utc timestamps calculated by created_timestamps' - ) - inMemory_percentages: Optional[List[int]] = Field( - None, description='Load percentage on querynode' - ) - partition_names: Optional[List[str]] = Field( - None, description='All partition names for this collection' - ) - partitionIDs: Optional[List[int]] = Field( - None, description='All partition ids for this collection' - ) - status: Optional[common.Status] = None - - -class UpdateCredentialRequest(BaseModel): - base: Optional[common.MsgBase] = None - created_utc_timestamps: Optional[int] = Field(None, description='create time') - modified_utc_timestamps: Optional[int] = Field(None, description='modify time') - newPassword: Optional[str] = Field(None, description='new password') - oldPassword: Optional[str] = Field(None, description='old password') - username: Optional[str] = Field(None, description='username') - - -class DescribeCollectionResponse(BaseModel): - aliases: Optional[List[str]] = Field( - None, description='The aliases of this collection' - ) - collection_name: Optional[str] = Field(None, description='The collection name') - collectionID: Optional[int] = Field(None, description='The collection id') - consistency_level: Optional[int] = Field( - None, - description='The consistency level that the collection used, modification is not supported now.', - ) - 
created_timestamp: Optional[int] = Field( - None, description='Hybrid timestamp in milvus' - ) - created_utc_timestamp: Optional[int] = Field( - None, description='The utc timestamp calculated by created_timestamp' - ) - physical_channel_names: Optional[List[str]] = Field( - None, description='System design related, users should not perceive' - ) - schema_: Optional[schema.CollectionSchema] = Field(None, alias='schema') - shards_num: Optional[int] = Field(None, description='The shards number you set.') - start_positions: Optional[List[common.KeyDataPair]] = Field( - None, description='The message ID/posititon when collection is created' - ) - status: Optional[common.Status] = None - virtual_channel_names: Optional[List[str]] = Field( - None, description='System design related, users should not perceive' - ) - - -class DescribeIndexResponse(BaseModel): - index_descriptions: Optional[List[IndexDescription]] = Field( - None, - description='All index informations, for now only return tha latest index you created for the collection.', - ) - status: Optional[common.Status] = None - - -class GetPersistentSegmentInfoResponse(BaseModel): - infos: Optional[List[PersistentSegmentInfo]] = None - status: Optional[common.Status] = None - - -class GetQuerySegmentInfoResponse(BaseModel): - infos: Optional[List[QuerySegmentInfo]] = None - status: Optional[common.Status] = None - - -class ReplicaInfo(BaseModel): - collectionID: Optional[int] = None - node_ids: Optional[List[int]] = Field(None, description='include leaders') - partition_ids: Optional[List[int]] = Field( - None, description='empty indicates to load collection' - ) - replicaID: Optional[int] = None - shard_replicas: Optional[List[ShardReplica]] = None - - -class GetReplicasResponse(BaseModel): - replicas: Optional[List[ReplicaInfo]] = None - status: Optional[common.Status] = None diff --git a/tests/restful_client/pytest.ini b/tests/restful_client/pytest.ini index da7be90208..b1b55479a5 100644 --- a/tests/restful_client/pytest.ini +++ b/tests/restful_client/pytest.ini @@ -1,12 +1,14 @@ [pytest] - - -addopts = --host 10.101.178.131 --html=/tmp/ci_logs/report.html --self-contained-html -v -# python3 -W ignore -m pytest +addopts = --strict --host 127.0.0.1 --port 19530 --username root --password Milvus --log-cli-level=INFO --capture=no log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s) log_date_format = %Y-%m-%d %H:%M:%S filterwarnings = - ignore::DeprecationWarning \ No newline at end of file + ignore::DeprecationWarning + +markers = + L0 : 'L0 case, high priority' + L1 : 'L1 case, second priority' + diff --git a/tests/restful_client/requirements.txt b/tests/restful_client/requirements.txt index da51bc9912..f243fe3ae9 100644 --- a/tests/restful_client/requirements.txt +++ b/tests/restful_client/requirements.txt @@ -1,2 +1,10 @@ -decorest~=0.1.0 -pydantic~=1.10.2 \ No newline at end of file +requests~=2.26.0 +urllib3==1.26.16 +loguru~=0.5.3 +pytest~=7.2.0 +pyyaml~=6.0 +numpy~=1.24.3 +allure-pytest>=2.8.18 +Faker==19.2.0 +pymilvus~=2.2.9 +scikit-learn~=1.1.3 \ No newline at end of file diff --git a/tests/restful_client/testcases/test_collection_operations.py b/tests/restful_client/testcases/test_collection_operations.py new file mode 100644 index 0000000000..574eccd3e3 --- /dev/null +++ b/tests/restful_client/testcases/test_collection_operations.py @@ -0,0 +1,395 @@ +import datetime +import random +import time +from utils.util_log import test_log as logger +from utils.utils import gen_collection_name 
+import pytest +from api.milvus import CollectionClient +from base.testbase import TestBase +import threading + + +@pytest.mark.L0 +class TestCreateCollection(TestBase): + + @pytest.mark.parametrize("vector_field", [None, "vector", "emb"]) + @pytest.mark.parametrize("primary_field", [None, "id", "doc_id"]) + @pytest.mark.parametrize("metric_type", ["L2", "IP"]) + @pytest.mark.parametrize("dim", [32, 32768]) + @pytest.mark.parametrize("db_name", ["prod", "default"]) + def test_create_collections_default(self, dim, metric_type, primary_field, vector_field, db_name): + """ + target: test create collection + method: create a collection with a simple schema + expected: create collection success + """ + self.create_database(db_name) + name = gen_collection_name() + client = self.collection_client + client.db_name = db_name + payload = { + "collectionName": name, + "dimension": dim, + "metricType": metric_type, + "primaryField": primary_field, + "vectorField": vector_field, + } + if primary_field is None: + del payload["primaryField"] + if vector_field is None: + del payload["vectorField"] + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + rsp = client.collection_describe(name) + assert rsp['code'] == 200 + assert rsp['data']['collectionName'] == name
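The quick-create payload above is all the v1 RESTful API needs to stand up a collection: a name plus a dimension, with the metric, primary field, and vector field optional. A minimal sketch of the same call over plain HTTP, assuming the `/v1/vector/collections/create` route that `CollectionClient` presumably wraps, and the local host/port/credentials from pytest.ini:

    import requests

    BASE_URL = "http://127.0.0.1:19530"  # assumed host/port, as in pytest.ini
    payload = {"collectionName": "book", "dimension": 128, "metricType": "L2"}
    rsp = requests.post(
        f"{BASE_URL}/v1/vector/collections/create",
        headers={"Authorization": "Bearer root:Milvus"},  # username:password token
        json=payload,
    ).json()
    assert rsp["code"] == 200  # business code in the body, not the HTTP status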
+ def test_create_collections_concurrent_with_same_param(self): + """ + target: test create collection with same param + method: concurrently create collections with the same param from multiple threads + expected: create collections all success + """ + concurrent_rsp = [] + + def create_collection(c_name, vector_dim, c_metric_type): + collection_payload = { + "collectionName": c_name, + "dimension": vector_dim, + "metricType": c_metric_type, + } + rsp = client.collection_create(collection_payload) + concurrent_rsp.append(rsp) + logger.info(rsp) + name = gen_collection_name() + dim = 128 + metric_type = "L2" + client = self.collection_client + threads = [] + for i in range(10): + t = threading.Thread(target=create_collection, args=(name, dim, metric_type,)) + threads.append(t) + for t in threads: + t.start() + for t in threads: + t.join() + time.sleep(10) + success_cnt = 0 + for rsp in concurrent_rsp: + if rsp["code"] == 200: + success_cnt += 1 + logger.info(concurrent_rsp) + assert success_cnt == 10 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + rsp = client.collection_describe(name) + assert rsp['code'] == 200 + assert rsp['data']['collectionName'] == name + assert f"FloatVector({dim})" in str(rsp['data']['fields']) + + def test_create_collections_concurrent_with_different_param(self): + """ + target: test create collection with different param + method: concurrently create collections with different params from multiple threads + expected: only one collection can succeed + """ + concurrent_rsp = [] + + def create_collection(c_name, vector_dim, c_metric_type): + collection_payload = { + "collectionName": c_name, + "dimension": vector_dim, + "metricType": c_metric_type, + } + rsp = client.collection_create(collection_payload) + concurrent_rsp.append(rsp) + logger.info(rsp) + name = gen_collection_name() + dim = 128 + client = self.collection_client + threads = [] + for i in range(0, 5): + t = threading.Thread(target=create_collection, args=(name, dim + i, "L2",)) + threads.append(t) + for i in range(5, 10): + t = threading.Thread(target=create_collection, args=(name, dim + i, "IP",)) + threads.append(t) + for t in threads: + t.start() + for t in threads: + t.join() + time.sleep(10) + success_cnt = 0 + for rsp in concurrent_rsp: + if rsp["code"] == 200: + success_cnt += 1 + logger.info(concurrent_rsp) + assert success_cnt == 1 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + rsp = client.collection_describe(name) + assert rsp['code'] == 200 + assert rsp['data']['collectionName'] == name + + def test_create_collections_with_invalid_api_key(self): + """ + target: test create collection with invalid api key (wrong username and password) + method: create collections with invalid api key + expected: create collection failed + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + client.api_key = "illegal_api_key" + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 1800 + + @pytest.mark.parametrize("name", [" ", "test_collection_" * 100, "test collection", "test/collection", "test\\collection"]) + def test_create_collections_with_invalid_collection_name(self, name): + """ + target: test create collection with invalid collection name + method: create collections with invalid collection name + expected: create collection failed with right error message + """ + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 1 + + +@pytest.mark.L0 +class TestListCollections(TestBase): + + def test_list_collections_default(self): + """ + target: test list collection with a simple schema + method: create collections and list them + expected: created collections are in list + """ + client = self.collection_client + name_list = [] + for i in range(2): + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "dimension": dim, + } + time.sleep(1) + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + name_list.append(name) + rsp = client.collection_list() + all_collections = rsp['data'] + for name in name_list: + assert name in all_collections + + def test_list_collections_with_invalid_api_key(self): + """ + target: test list collection with an invalid api key + method: list collection with invalid api key + expected: raise error with right error code and message + """ + client = self.collection_client + name_list = [] + for i in range(2): + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "dimension": dim, + } + time.sleep(1) + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + name_list.append(name) + client = self.collection_client + client.api_key = "illegal_api_key" + rsp = client.collection_list() + assert rsp['code'] == 1800
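The list tests above pace themselves with fixed `time.sleep` calls before checking `collection_list`. Where timing matters, a polling helper is usually less flaky; a hypothetical sketch (not part of the suite), assuming the same response shape used above:

    import time

    def wait_for_collection(client, name, timeout=10.0, interval=0.5):
        """Poll collection_list until `name` appears or the timeout elapses."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            rsp = client.collection_list()
            if rsp["code"] == 200 and name in rsp["data"]:
                return True
            time.sleep(interval)
        return False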
+ + +@pytest.mark.L0 +class TestDescribeCollection(TestBase): + + def test_describe_collections_default(self): + """ + target: test describe collection with a simple schema + method: describe collection + expected: info in description is the same as the params passed to create collection + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + rsp = client.collection_describe(name) + assert rsp['code'] == 200 + assert rsp['data']['collectionName'] == name + assert f"FloatVector({dim})" in str(rsp['data']['fields']) + + def test_describe_collections_with_invalid_api_key(self): + """ + target: test describe collection with invalid api key + method: describe collection with invalid api key + expected: raise error with right error code and message + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + illegal_client = CollectionClient(self.url, "illegal_api_key") + rsp = illegal_client.collection_describe(name) + assert rsp['code'] == 1800 + + def test_describe_collections_with_invalid_collection_name(self): + """ + target: test describe collection with invalid collection name + method: describe collection with invalid collection name + expected: raise error with right error code and message + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # describe collection + invalid_name = "invalid_name" + rsp = client.collection_describe(invalid_name) + assert rsp['code'] == 1 + + +@pytest.mark.L0 +class TestDropCollection(TestBase): + + def test_drop_collections_default(self): + """ + target: test drop collection with a simple schema + method: drop collection + expected: dropped collection is no longer in collection list + """ + clo_list = [] + for i in range(5): + time.sleep(1) + name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + payload = { + "collectionName": name, + "dimension": 128, + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 200 + clo_list.append(name) + rsp = self.collection_client.collection_list() + all_collections = rsp['data'] + for name in clo_list: + assert name in all_collections + for name in clo_list: + time.sleep(0.2) + payload = { + "collectionName": name, + } + rsp = self.collection_client.collection_drop(payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_list() + all_collections = rsp['data'] + for name in clo_list: + assert name not in all_collections + + def test_drop_collections_with_invalid_api_key(self): + """ + target: test drop collection with invalid api key + method: drop collection with invalid api key + expected: raise error with right error code and message; collection still in collection list + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # drop collection + payload = { + "collectionName": name, + } + illegal_client = CollectionClient(self.url, "invalid_api_key") + rsp = illegal_client.collection_drop(payload) + assert rsp['code'] == 1800 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections
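A pattern worth naming before the next negative case: the suite distinguishes failures purely by the `code` field of the JSON body, and, as the assertions in these files imply, the codes exercised are 200 on success, 1 and 100 for bad parameters or unknown collection names, 800 for an unknown database, 1800 for rejected credentials, and 1804 when insert data does not match the schema. A tiny helper one could use to make that explicit (hypothetical, not in the suite):

    def assert_rest_error(rsp, expected_code):
        """Check that a v1 RESTful response failed with the expected business code."""
        assert rsp["code"] == expected_code, f"got {rsp['code']}: {rsp.get('message')}"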
+ def test_drop_collections_with_invalid_collection_name(self): + """ + target: test drop collection with invalid collection name + method: drop collection with invalid collection name + expected: raise error with right error code and message + """ + name = gen_collection_name() + dim = 128 + client = self.collection_client + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = client.collection_create(payload) + assert rsp['code'] == 200 + rsp = client.collection_list() + all_collections = rsp['data'] + assert name in all_collections + # drop collection + invalid_name = "invalid_name" + payload = { + "collectionName": invalid_name, + } + rsp = client.collection_drop(payload) + assert rsp['code'] == 100 diff --git a/tests/restful_client/testcases/test_e2e.py b/tests/restful_client/testcases/test_e2e.py deleted file mode 100644 index 525ad9913e..0000000000 --- a/tests/restful_client/testcases/test_e2e.py +++ /dev/null @@ -1,49 +0,0 @@ -from time import sleep -from common import common_type as ct -from common import common_func as cf -from base.client_base import TestBase -from utils.util_log import test_log as log - - -class TestDefault(TestBase): - - def test_e2e(self): - collection_name, schema = self.init_collection() - nb = ct.default_nb - # insert - res = self.entity_service.insert(collection_name=collection_name, fields_data=cf.gen_fields_data(schema, nb=nb), - num_rows=nb) - log.info(f"insert {nb} rows into collection {collection_name}, response: {res}") - # flush - res = self.entity_service.flush(collection_names=[collection_name]) - log.info(f"flush collection {collection_name}, response: {res}") - # create index for vector field - vector_field_name = cf.get_vector_field(schema) - vector_index_params = cf.gen_index_params(index_type="HNSW") - res = self.index_service.create_index(collection_name=collection_name, field_name=vector_field_name, - extra_params=vector_index_params) - log.info(f"create index for vector field {vector_field_name}, response: {res}") - # load - res = self.collection_service.load_collection(collection_name=collection_name) - log.info(f"load collection {collection_name}, response: {res}") - - sleep(5) - # search - vectors = cf.gen_vectors(nq=ct.default_nq, schema=schema) - res = self.entity_service.search(collection_name=collection_name, vectors=vectors, - output_fields=[ct.default_int64_field_name], - search_params=cf.gen_search_params()) - log.info(f"search collection {collection_name}, response: {res}") - - # hybrid search - res = self.entity_service.search(collection_name=collection_name, vectors=vectors, - output_fields=[ct.default_int64_field_name], - search_params=cf.gen_search_params(), - dsl=ct.default_dsl) - log.info(f"hybrid search collection {collection_name}, response: {res}") - # query - res = self.entity_service.query(collection_name=collection_name, expr=ct.default_expr) - - log.info(f"query collection {collection_name}, response: {res}") - - diff --git a/tests/restful_client/testcases/test_vector_operations.py b/tests/restful_client/testcases/test_vector_operations.py new file mode 100644 index 0000000000..d62d5783b7 --- /dev/null +++ b/tests/restful_client/testcases/test_vector_operations.py @@ -0,0 +1,984 @@ +import datetime +import random +import time +from sklearn import preprocessing +import numpy as np +import sys +import json +from utils import constant +from utils.utils import gen_collection_name +from utils.util_log import test_log as logger +import pytest +from api.milvus import VectorClient +from 
base.testbase import TestBase +from utils.utils import (get_data_by_fields, get_data_by_payload, get_common_fields_by_data) + + +class TestInsertVector(TestBase): + + @pytest.mark.L0 + @pytest.mark.parametrize("insert_round", [2, 1]) + @pytest.mark.parametrize("nb", [100, 10, 1]) + @pytest.mark.parametrize("dim", [32, 128]) + @pytest.mark.parametrize("primary_field", ["id", "url"]) + @pytest.mark.parametrize("vector_field", ["vector", "embedding"]) + @pytest.mark.parametrize("db_name", ["prod", "default"]) + def test_insert_vector_with_simple_payload(self, db_name, vector_field, primary_field, nb, dim, insert_round): + """ + Insert a vector with a simple payload + """ + self.update_database(db_name=db_name) + # create a collection + name = gen_collection_name() + collection_payload = { + "collectionName": name, + "dimension": dim, + "primaryField": primary_field, + "vectorField": vector_field, + } + rsp = self.collection_client.collection_create(collection_payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + logger.info(f"rsp: {rsp}") + assert rsp['code'] == 200 + # insert data + for i in range(insert_round): + data = get_data_by_payload(collection_payload, nb) + payload = { + "collectionName": name, + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 200 + assert rsp['data']['insertCount'] == nb + logger.info("finished") + + @pytest.mark.L0 + @pytest.mark.parametrize("insert_round", [10]) + def test_insert_vector_with_multi_round(self, insert_round): + """ + Insert a vector with a simple payload + """ + # create a collection + name = gen_collection_name() + collection_payload = { + "collectionName": name, + "dimension": 768, + } + rsp = self.collection_client.collection_create(collection_payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + logger.info(f"rsp: {rsp}") + assert rsp['code'] == 200 + # insert data + nb = 300 + for i in range(insert_round): + data = get_data_by_payload(collection_payload, nb) + payload = { + "collectionName": name, + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 200 + assert rsp['data']['insertCount'] == nb + logger.info("finished") + + def test_insert_vector_with_invalid_api_key(self): + """ + Insert a vector with invalid api key + """ + # create a collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + assert rsp['code'] == 200 + # insert data + nb = 10 + data = [ + { + "vector": [np.float64(random.random()) for _ in range(dim)], + } for _ in range(nb) + ] + payload = { + "collectionName": name, + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + client = self.vector_client + client.api_key = "invalid_api_key" + rsp = client.vector_insert(payload) + assert rsp['code'] == 1800 + + def test_insert_vector_with_invalid_collection_name(self): + """ + Insert a vector with an invalid collection name + """ + + # create a collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": 
name, + "dimension": dim, + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + assert rsp['code'] == 200 + # insert data + nb = 100 + data = get_data_by_payload(payload, nb) + payload = { + "collectionName": "invalid_collection_name", + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 1 + + def test_insert_vector_with_invalid_database_name(self): + """ + Insert a vector with an invalid database name + """ + # create a collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + assert rsp['code'] == 200 + # insert data + nb = 10 + data = get_data_by_payload(payload, nb) + payload = { + "collectionName": name, + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload, db_name="invalid_database") + assert rsp['code'] == 800 + + def test_insert_vector_with_mismatch_dim(self): + """ + Insert a vector whose dim does not match the collection schema + """ + # create a collection + name = gen_collection_name() + dim = 32 + payload = { + "collectionName": name, + "dimension": dim, + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 200 + rsp = self.collection_client.collection_describe(name) + assert rsp['code'] == 200 + # insert data + nb = 1 + data = [ + { + "vector": [np.float64(random.random()) for _ in range(dim + 1)], + } for _ in range(nb) + ] + payload = { + "collectionName": name, + "data": data, + } + body_size = sys.getsizeof(json.dumps(payload)) + logger.info(f"body size: {body_size / 1024 / 1024} MB") + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 1804 + assert rsp['message'] == "fail to deal the insert data"
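The search tests below build their query vectors by L2-normalizing random data, so that under the IP metric the distances behave like cosine similarities and sort in descending order. The exact recipe the tests use, isolated for reference:

    import random
    import numpy as np
    from sklearn import preprocessing

    dim = 128
    raw = np.array([random.random() for _ in range(dim)])
    # preprocessing.normalize expects a 2-D array; take row 0 back out as a list
    vector_to_search = preprocessing.normalize([raw])[0].tolist()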
+ + +class TestSearchVector(TestBase): + + @pytest.mark.L0 + @pytest.mark.parametrize("metric_type", ["IP", "L2"]) + def test_search_vector_with_simple_payload(self, metric_type): + """ + Search a vector with a simple payload + """ + name = gen_collection_name() + self.name = name + self.init_collection(name, metric_type=metric_type) + + # search data + dim = 128 + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + payload = { + "collectionName": name, + "vector": vector_to_search, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + limit = int(payload.get("limit", 100)) + assert len(res) == limit + ids = [item['id'] for item in res] + assert len(ids) == len(set(ids)) + distance = [item['distance'] for item in res] + if metric_type == "L2": + assert distance == sorted(distance) + if metric_type == "IP": + assert distance == sorted(distance, reverse=True) + + @pytest.mark.L0 + @pytest.mark.parametrize("sum_limit_offset", [16384, 16385]) + def test_search_vector_with_exceed_sum_limit_offset(self, sum_limit_offset): + """ + Search a vector with limit + offset around the maximum allowed sum + """ + max_search_sum_limit_offset = constant.MAX_SUM_OFFSET_AND_LIMIT + name = gen_collection_name() + self.name = name + nb = sum_limit_offset + 2000 + metric_type = "IP" + limit = 100 + self.init_collection(name, metric_type=metric_type, nb=nb, batch_size=2000) + + # search data + dim = 128 + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + payload = { + "collectionName": name, + "vector": vector_to_search, + "limit": limit, + "offset": sum_limit_offset - limit, + } + rsp = self.vector_client.vector_search(payload) + if sum_limit_offset > max_search_sum_limit_offset: + assert rsp['code'] == 1 + return + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + limit = int(payload.get("limit", 100)) + assert len(res) == limit + ids = [item['id'] for item in res] + assert len(ids) == len(set(ids)) + distance = [item['distance'] for item in res] + if metric_type == "L2": + assert distance == sorted(distance) + if metric_type == "IP": + assert distance == sorted(distance, reverse=True) + + @pytest.mark.L0 + @pytest.mark.parametrize("level", [0, 1, 2]) + @pytest.mark.parametrize("offset", [0, 10, 100]) + @pytest.mark.parametrize("limit", [1, 100]) + @pytest.mark.parametrize("metric_type", ["L2", "IP"]) + def test_search_vector_with_complex_payload(self, limit, offset, level, metric_type): + """ + Search a vector with a complex payload + """ + name = gen_collection_name() + self.name = name + nb = limit + offset + 100 + dim = 128 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type) + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": "uid >= 0", + "limit": limit, + "offset": offset, + } + rsp = self.vector_client.vector_search(payload) + if offset + limit > constant.MAX_SUM_OFFSET_AND_LIMIT: + assert rsp['code'] == 90126 + return + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) == limit + for item in res: + assert item.get("uid") >= 0 + for field in output_fields: + assert field in item + + @pytest.mark.L0 + @pytest.mark.parametrize("filter_expr", ["uid >= 0", "uid >= 0 and uid < 100", "uid in [1,2,3]"]) + def test_search_vector_with_complex_int_filter(self, filter_expr): + """ + Search a vector with an int64 filter + """ + name = gen_collection_name() + self.name = name + nb = 200 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb) + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": filter_expr, + "limit": limit, + "offset": 0, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) <= limit + for item in res: + uid = item.get("uid") + assert eval(filter_expr) is True
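The verification loop above relies on the fact that these simple Milvus filter expressions happen to be valid Python as well: binding `uid` locally and re-evaluating the expression re-checks each hit on the client side. In isolation:

    filter_expr = "uid >= 0 and uid < 100"
    hit = {"uid": 42}                 # one search hit, as returned in rsp['data']
    uid = hit["uid"]                  # bind the field name the expression uses
    assert eval(filter_expr) is True  # same per-hit check the test performs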
+ + @pytest.mark.L0 + @pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""]) + def test_search_vector_with_complex_varchar_filter(self, filter_expr): + """ + Search a vector with a varchar filter + """ + name = gen_collection_name() + self.name = name + nb = 200 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb) + names = [] + for item in data: + names.append(item.get("name")) + names.sort() + logger.info(f"names: {names}") + mid = len(names) // 2 + prefix = names[mid][0:2] + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + filter_expr = filter_expr.replace("placeholder", prefix) + logger.info(f"filter_expr: {filter_expr}") + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": filter_expr, + "limit": limit, + "offset": 0, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) <= limit + for item in res: + name = item.get("name") + logger.info(f"name: {name}") + if ">" in filter_expr: + assert name > prefix + if "like" in filter_expr: + assert name.startswith(prefix) + + @pytest.mark.L0 + @pytest.mark.parametrize("filter_expr", ["uid < 100 and name > \"placeholder\"", + "uid < 100 and name like \"placeholder%\"" + ]) + def test_search_vector_with_complex_int64_varchar_and_filter(self, filter_expr): + """ + Search a vector with a combined int64 and varchar filter + """ + name = gen_collection_name() + self.name = name + nb = 200 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb) + names = [] + for item in data: + names.append(item.get("name")) + names.sort() + logger.info(f"names: {names}") + mid = len(names) // 2 + prefix = names[mid][0:2] + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + filter_expr = filter_expr.replace("placeholder", prefix) + logger.info(f"filter_expr: {filter_expr}") + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": filter_expr, + "limit": limit, + "offset": 0, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) <= limit + for item in res: + uid = item.get("uid") + name = item.get("name") + logger.info(f"name: {name}") + uid_expr = filter_expr.split("and")[0] + assert eval(uid_expr) is True + varchar_expr = filter_expr.split("and")[1] + if ">" in varchar_expr: + assert name > prefix + if "like" in varchar_expr: + assert name.startswith(prefix) + + @pytest.mark.parametrize("limit", [0, 16385]) + def test_search_vector_with_invalid_limit(self, limit): + """ + Search a vector with an invalid limit + """ + name = gen_collection_name() + self.name = name + dim = 128 + schema_payload, data = self.init_collection(name, dim=dim) + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": "uid >= 0", + "limit": limit, + "offset": 0, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 1
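The two negative tests on either side of this point pin down the window the server accepts, as the parametrized values suggest: `limit` must stay within 1..16384, `offset` must be non-negative and bounded, and, per the earlier sum tests, `limit + offset` must not exceed `constant.MAX_SUM_OFFSET_AND_LIMIT`. A compact client-side statement of that contract, with 16384 assumed as the cap the constants module mirrors:

    MAX_SUM_OFFSET_AND_LIMIT = 16384  # assumed value, matching the parametrized cases

    def is_valid_search_window(limit: int, offset: int) -> bool:
        """Client-side mirror of the limit/offset validation these tests probe."""
        return (1 <= limit <= MAX_SUM_OFFSET_AND_LIMIT
                and 0 <= offset <= MAX_SUM_OFFSET_AND_LIMIT
                and limit + offset <= MAX_SUM_OFFSET_AND_LIMIT)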
@pytest.mark.parametrize("offset", [-1, 100_001]) + def test_search_vector_with_invalid_offset(self, offset): + """ + Search a vector with a simple payload + """ + name = gen_collection_name() + self.name = name + dim = 128 + schema_payload, data = self.init_collection(name, dim=dim) + vector_field = schema_payload.get("vectorField") + # search data + dim = 128 + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "vector": vector_to_search, + "outputFields": output_fields, + "filter": "uid >= 0", + "limit": 100, + "offset": offset, + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 1 + + def test_search_vector_with_illegal_api_key(self): + """ + Search a vector with an illegal api key + """ + pass + + def test_search_vector_with_invalid_collection_name(self): + """ + Search a vector with an invalid collection name + """ + pass + + def test_search_vector_with_invalid_output_field(self): + """ + Search a vector with an invalid output field + """ + pass + + @pytest.mark.parametrize("invalid_expr", ["invalid_field > 0", "12-s", "中文", "a", " "]) + def test_search_vector_with_invalid_expression(self, invalid_expr): + """ + Search a vector with an invalid expression + """ + pass + + def test_search_vector_with_invalid_vector_field(self): + """ + Search a vector with an invalid vector field for ann search + """ + pass + + @pytest.mark.parametrize("dim_offset", [1, -1]) + def test_search_vector_with_mismatch_vector_dim(self, dim_offset): + """ + Search a vector with a mismatch vector dim + """ + pass + + +class TestQueryVector(TestBase): + + @pytest.mark.L0 + @pytest.mark.parametrize("expr", ["10+20 <= uid < 20+30", "uid in [1,2,3,4]", + "uid > 0", "uid >= 0", "uid > 0", + "uid > -100 and uid < 100"]) + @pytest.mark.parametrize("include_output_fields", [True, False]) + @pytest.mark.parametrize("partial_fields", [True, False]) + def test_query_vector_with_int64_filter(self, expr, include_output_fields, partial_fields): + """ + Query a vector with a simple payload + """ + name = gen_collection_name() + self.name = name + schema_payload, data = self.init_collection(name) + output_fields = get_common_fields_by_data(data) + if partial_fields: + output_fields = output_fields[:len(output_fields) // 2] + if "uid" not in output_fields: + output_fields.append("uid") + else: + output_fields = output_fields + + # query data + payload = { + "collectionName": name, + "filter": expr, + "limit": 100, + "offset": 0, + "outputFields": output_fields + } + if not include_output_fields: + payload.pop("outputFields") + if 'vector' in output_fields: + output_fields.remove("vector") + time.sleep(5) + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + for r in res: + uid = r['uid'] + assert eval(expr) is True + for field in output_fields: + assert field in r + + @pytest.mark.L0 + @pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""]) + @pytest.mark.parametrize("include_output_fields", [True, False]) + def test_query_vector_with_varchar_filter(self, filter_expr, include_output_fields): + """ + Query a vector with a complex payload + """ + name = gen_collection_name() + self.name = name + nb = 200 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb) + names = [] 
+
+
+class TestQueryVector(TestBase):
+
+    @pytest.mark.L0
+    @pytest.mark.parametrize("expr", ["10+20 <= uid < 20+30", "uid in [1,2,3,4]",
+                                      "uid > 0", "uid >= 0",
+                                      "uid > -100 and uid < 100"])
+    @pytest.mark.parametrize("include_output_fields", [True, False])
+    @pytest.mark.parametrize("partial_fields", [True, False])
+    def test_query_vector_with_int64_filter(self, expr, include_output_fields, partial_fields):
+        """
+        Query with int64 filter expressions
+        """
+        name = gen_collection_name()
+        self.name = name
+        schema_payload, data = self.init_collection(name)
+        output_fields = get_common_fields_by_data(data)
+        if partial_fields:
+            output_fields = output_fields[:len(output_fields) // 2]
+            if "uid" not in output_fields:
+                output_fields.append("uid")
+
+        # query data
+        payload = {
+            "collectionName": name,
+            "filter": expr,
+            "limit": 100,
+            "offset": 0,
+            "outputFields": output_fields
+        }
+        if not include_output_fields:
+            payload.pop("outputFields")
+        # the vector field is excluded from the requested output fields and the check below
+        if 'vector' in output_fields:
+            output_fields.remove("vector")
+        time.sleep(5)
+        rsp = self.vector_client.vector_query(payload)
+        assert rsp['code'] == 200
+        res = rsp['data']
+        logger.info(f"res: {len(res)}")
+        for r in res:
+            uid = r['uid']
+            assert eval(expr) is True
+            for field in output_fields:
+                assert field in r
+
+    @pytest.mark.L0
+    @pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""])
+    @pytest.mark.parametrize("include_output_fields", [True, False])
+    def test_query_vector_with_varchar_filter(self, filter_expr, include_output_fields):
+        """
+        Query with varchar filter expressions
+        """
+        name = gen_collection_name()
+        self.name = name
+        nb = 200
+        dim = 128
+        limit = 100
+        schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
+        names = []
+        for item in data:
+            names.append(item.get("name"))
+        names.sort()
+        logger.info(f"names: {names}")
+        mid = len(names) // 2
+        prefix = names[mid][0:2]
+        # query data
+        output_fields = get_common_fields_by_data(data)
+        filter_expr = filter_expr.replace("placeholder", prefix)
+        logger.info(f"filter_expr: {filter_expr}")
+        payload = {
+            "collectionName": name,
+            "outputFields": output_fields,
+            "filter": filter_expr,
+            "limit": limit,
+            "offset": 0,
+        }
+        if not include_output_fields:
+            payload.pop("outputFields")
+        rsp = self.vector_client.vector_query(payload)
+        assert rsp['code'] == 200
+        res = rsp['data']
+        logger.info(f"res: {len(res)}")
+        assert len(res) <= limit
+        for item in res:
+            name = item.get("name")
+            logger.info(f"name: {name}")
+            if ">" in filter_expr:
+                assert name > prefix
+            if "like" in filter_expr:
+                assert name.startswith(prefix)
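+
+    # For reference (illustrative values): if the sorted fake names yield the
+    # prefix "Al", the two parametrized expressions above become
+    #     name > "Al"       -> lexicographic greater-than check
+    #     name like "Al%"   -> prefix match
+    # e.g. 'name like "placeholder%"'.replace("placeholder", "Al")
+    #      == 'name like "Al%"'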
+
+    @pytest.mark.parametrize("sum_of_limit_offset", [16384, 16385])
+    def test_query_vector_with_large_sum_of_limit_offset(self, sum_of_limit_offset):
+        """
+        Query with a sum of limit and offset larger than the max value
+        """
+        # keep in sync with MAX_SUM_OFFSET_AND_LIMIT in utils/constant.py
+        max_sum_of_limit_offset = 16384
+        name = gen_collection_name()
+        filter_expr = "name > \"placeholder\""
+        self.name = name
+        nb = 200
+        dim = 128
+        limit = 100
+        offset = sum_of_limit_offset - limit
+        schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
+        names = []
+        for item in data:
+            names.append(item.get("name"))
+        names.sort()
+        logger.info(f"names: {names}")
+        mid = len(names) // 2
+        prefix = names[mid][0:2]
+        # query data
+        output_fields = get_common_fields_by_data(data)
+        filter_expr = filter_expr.replace("placeholder", prefix)
+        logger.info(f"filter_expr: {filter_expr}")
+        payload = {
+            "collectionName": name,
+            "outputFields": output_fields,
+            "filter": filter_expr,
+            "limit": limit,
+            "offset": offset,
+        }
+        rsp = self.vector_client.vector_query(payload)
+        if sum_of_limit_offset > max_sum_of_limit_offset:
+            assert rsp['code'] == 1
+            return
+        assert rsp['code'] == 200
+        res = rsp['data']
+        logger.info(f"res: {len(res)}")
+        assert len(res) <= limit
+        for item in res:
+            name = item.get("name")
+            logger.info(f"name: {name}")
+            if ">" in filter_expr:
+                assert name > prefix
+            if "like" in filter_expr:
+                assert name.startswith(prefix)
+
+
+class TestGetVector(TestBase):
+
+    @pytest.mark.L0
+    def test_get_vector_with_simple_payload(self):
+        """
+        Search then get a vector by id with a simple payload
+        """
+        name = gen_collection_name()
+        self.name = name
+        self.init_collection(name)
+
+        # search data to collect ids
+        dim = 128
+        vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
+        payload = {
+            "collectionName": name,
+            "vector": vector_to_search,
+        }
+        rsp = self.vector_client.vector_search(payload)
+        assert rsp['code'] == 200
+        res = rsp['data']
+        logger.info(f"res: {len(res)}")
+        limit = int(payload.get("limit", 100))
+        assert len(res) == limit
+        ids = [item['id'] for item in res]
+        assert len(ids) == len(set(ids))
+        # get one entity by id
+        payload = {
+            "collectionName": name,
+            "outputFields": ["*"],
+            "id": ids[0],
+        }
+        rsp = self.vector_client.vector_get(payload)
+        assert rsp['code'] == 200
+        res = rsp['data']
+        logger.info(f"res: {res}")
+        for item in res:
+            assert item['id'] == ids[0]
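+
+    # Illustrative payload shapes for vector_get (drawn from the tests around
+    # here; the id values are examples only):
+    #     {"collectionName": name, "id": 42}            # a single primary key
+    #     {"collectionName": name, "id": [42, 43, 44]}  # a list of primary keys
+    # test_get_vector_complex below assumes an unknown id is silently skipped
+    # in the result rather than rejected with an error.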
"outputFields": output_fields, + "filter": f"uid in {uids}", + } + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + ids = [] + for r in res: + ids.append(r['id']) + logger.info(f"ids: {len(ids)}") + id_to_get = ids + # delete by id list + payload = { + "collectionName": name, + "id": id_to_get + } + client = self.vector_client + client.api_key = "invalid_api_key" + rsp = client.vector_delete(payload) + assert rsp['code'] == 1800 + + def test_delete_vector_with_invalid_collection_name(self): + """ + Delete a vector with an invalid collection name + """ + name = gen_collection_name() + self.name = name + self.init_collection(name, dim=128, nb=3000) + + # query data + # expr = f"id in {[i for i in range(10)]}".replace("[", "(").replace("]", ")") + expr = "id > 0" + payload = { + "collectionName": name, + "filter": expr, + "limit": 3000, + "offset": 0, + "outputFields": ["id", "uid"] + } + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + id_list = [r['id'] for r in res] + delete_expr = f"id in {[i for i in id_list[:10]]}" + # query data before delete + payload = { + "collectionName": name, + "filter": delete_expr, + "limit": 3000, + "offset": 0, + "outputFields": ["id", "uid"] + } + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + # delete data + payload = { + "collectionName": name + "_invalid", + "filter": delete_expr, + } + rsp = self.vector_client.vector_delete(payload) + assert rsp['code'] == 1 + + def test_delete_vector_with_non_primary_key(self): + """ + Delete a vector with a non-primary key, expect no data were deleted + """ + name = gen_collection_name() + self.name = name + self.init_collection(name, dim=128, nb=300) + expr = "uid > 0" + payload = { + "collectionName": name, + "filter": expr, + "limit": 3000, + "offset": 0, + "outputFields": ["id", "uid"] + } + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + logger.info(f"res: {len(res)}") + id_list = [r['uid'] for r in res] + delete_expr = f"uid in {[i for i in id_list[:10]]}" + # query data before delete + payload = { + "collectionName": name, + "filter": delete_expr, + "limit": 3000, + "offset": 0, + "outputFields": ["id", "uid"] + } + rsp = self.vector_client.vector_query(payload) + assert rsp['code'] == 200 + res = rsp['data'] + num_before_delete = len(res) + logger.info(f"res: {len(res)}") + # delete data + payload = { + "collectionName": name, + "filter": delete_expr, + } + rsp = self.vector_client.vector_delete(payload) + # query data after delete + payload = { + "collectionName": name, + "filter": delete_expr, + "limit": 3000, + "offset": 0, + "outputFields": ["id", "uid"] + } + time.sleep(1) + rsp = self.vector_client.vector_query(payload) + assert len(rsp["data"]) == num_before_delete diff --git a/tests/restful_client/utils/constant.py b/tests/restful_client/utils/constant.py new file mode 100644 index 0000000000..adeb3c8b2c --- /dev/null +++ b/tests/restful_client/utils/constant.py @@ -0,0 +1,2 @@ + +MAX_SUM_OFFSET_AND_LIMIT = 16384 diff --git a/tests/restful_client/utils/util_log.py b/tests/restful_client/utils/util_log.py index 4743f30b9d..fbd0f84f75 100644 --- a/tests/restful_client/utils/util_log.py +++ b/tests/restful_client/utils/util_log.py @@ -1,4 +1,5 @@ import logging +from loguru import logger as loguru_logger import sys 
diff --git a/tests/restful_client/utils/util_wrapper.py b/tests/restful_client/utils/util_wrapper.py
deleted file mode 100644
index d122054c68..0000000000
--- a/tests/restful_client/utils/util_wrapper.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import time
-from datetime import datetime
-import functools
-from utils.util_log import test_log as log
-
-DEFAULT_FMT = '[{start_time}] [{elapsed:0.8f}s] {collection_name} {func_name} -> {res!r}'
-
-
-def trace(fmt=DEFAULT_FMT, prefix='test', flag=True):
-    def decorate(func):
-        @functools.wraps(func)
-        def inner_wrapper(*args, **kwargs):
-            # args[0] is an instance of ApiCollectionWrapper class
-            flag = args[0].active_trace
-            if flag:
-                start_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
-                t0 = time.perf_counter()
-                res, result = func(*args, **kwargs)
-                elapsed = time.perf_counter() - t0
-                end_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
-                func_name = func.__name__
-                collection_name = args[0].collection.name
-                # arg_lst = [repr(arg) for arg in args[1:]][:100]
-                # arg_lst.extend(f'{k}={v!r}' for k, v in kwargs.items())
-                # arg_str = ', '.join(arg_lst)[:200]
-
-                log_str = f"[{prefix}]" + fmt.format(**locals())
-                # TODO: add report function in this place, like uploading to influxdb
-                # it is better a async way to do this, in case of blocking the request processing
-                log.info(log_str)
-                return res, result
-            else:
-                res, result = func(*args, **kwargs)
-                return res, result
-
-        return inner_wrapper
-
-    return decorate
diff --git a/tests/restful_client/utils/utils.py b/tests/restful_client/utils/utils.py
new file mode 100644
index 0000000000..06942c181b
--- /dev/null
+++ b/tests/restful_client/utils/utils.py
@@ -0,0 +1,155 @@
+import random
+import time
+import string
+from faker import Faker
+import numpy as np
+from sklearn import preprocessing
+import requests
+from loguru import logger
+import datetime
+
+fake = Faker()
+
+
+def random_string(length=8):
+    letters = string.ascii_letters
+    return ''.join(random.choice(letters) for _ in range(length))
+
+
+def gen_collection_name(prefix="test_collection", length=8):
+    name = f'{prefix}_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + random_string(length=length)
+    return name
+
+
+def admin_password():
+    return "Milvus"
+
+
+def invalid_cluster_name():
+    res = [
+        "demo" * 100,
+        "demo" + "!",
+        "demo" + "@",
+    ]
+    return res
+
+
+def wait_cluster_be_ready(cluster_id, client, timeout=120):
+    t0 = time.time()
+    while time.time() - t0 < timeout:
+        rsp = client.cluster_describe(cluster_id)
+        if rsp['code'] == 200:
+            if rsp['data']['status'] == "RUNNING":
+                return time.time() - t0
+        time.sleep(1)
+        logger.debug("wait cluster to be ready, cost time: %s" % (time.time() - t0))
+    return -1
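+
+# Illustrative usage of wait_cluster_be_ready (the cluster id is a made-up
+# example; the client is only assumed to expose cluster_describe, as used above):
+#   elapsed = wait_cluster_be_ready("in01-xxxxxxxx", client, timeout=300)
+#   assert elapsed != -1, "cluster did not reach RUNNING within the timeout"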
+
+
+def gen_data_by_type(field):
+    data_type = field["type"]
+    if data_type == "bool":
+        return random.choice([True, False])
+    if data_type == "int8":
+        return random.randint(-128, 127)
+    if data_type == "int16":
+        return random.randint(-32768, 32767)
+    if data_type == "int32":
+        return random.randint(-2147483648, 2147483647)
+    if data_type == "int64":
+        return random.randint(-9223372036854775808, 9223372036854775807)
+    if data_type == "float32":
+        return np.float64(random.random())  # Object of type float32 is not JSON serializable, so use float64
+    if data_type == "float64":
+        return np.float64(random.random())
+    if "varchar" in data_type:
+        length = int(data_type.split("(")[1].split(")")[0])
+        return "".join([chr(random.randint(97, 122)) for _ in range(length)])
+    if "floatVector" in data_type:
+        dim = int(data_type.split("(")[1].split(")")[0])
+        return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
+    return None
+
+
+def get_data_by_fields(fields, nb):
+    # logger.info(f"fields: {fields}")
+    fields_not_auto_id = []
+    for field in fields:
+        if not field.get("autoId", False):
+            fields_not_auto_id.append(field)
+    # logger.info(f"fields_not_auto_id: {fields_not_auto_id}")
+    data = []
+    for i in range(nb):
+        tmp = {}
+        for field in fields_not_auto_id:
+            tmp[field["name"]] = gen_data_by_type(field)
+        data.append(tmp)
+    return data
+
+
+def get_random_json_data(uid=None):
+    # gen random dict data
+    if uid is None:
+        uid = 0
+    data = {"uid": uid, "name": fake.name(), "address": fake.address(), "text": fake.text(), "email": fake.email(),
+            "phone_number": fake.phone_number(),
+            "json": {
+                "name": fake.name(),
+                "address": fake.address()
+            }
+            }
+    for _ in range(random.randint(1, 10)):
+        data["key" + str(random.randint(1, 100_000))] = "value" + str(random.randint(1, 100_000))
+    return data
+
+
+def get_data_by_payload(payload, nb=100):
+    dim = payload.get("dimension", 128)
+    vector_field = payload.get("vectorField", "vector")
+    data = []
+    if nb == 1:
+        data = [{
+            vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
+            **get_random_json_data()
+        }]
+    else:
+        for i in range(nb):
+            data.append({
+                vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
+                **get_random_json_data(uid=i)
+            })
+    return data
+
+
+def get_common_fields_by_data(data, exclude_fields=None):
+    if isinstance(data, dict):
+        data = [data]
+    if not isinstance(data, list):
+        raise Exception("data must be list or dict")
+    common_fields = set(data[0].keys())
+    for d in data:
+        keys = set(d.keys())
+        common_fields = common_fields.intersection(keys)
+    if exclude_fields is not None:
+        exclude_fields = set(exclude_fields)
+        common_fields = common_fields.difference(exclude_fields)
+    return list(common_fields)
+
+
+def get_all_fields_by_data(data, exclude_fields=None):
+    fields = set()
+    for d in data:
+        keys = list(d.keys())
+        fields = fields.union(keys)  # union returns a new set, so reassign
+    if exclude_fields is not None:
+        exclude_fields = set(exclude_fields)
+        fields = fields.difference(exclude_fields)
+    return list(fields)
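+
+
+# Illustrative round trip of the helpers above (the field dicts follow the
+# shape consumed by gen_data_by_type; names and types here are examples only):
+#   fields = [
+#       {"name": "uid", "type": "int64"},
+#       {"name": "name", "type": "varchar(16)"},
+#       {"name": "vector", "type": "floatVector(128)"},
+#   ]
+#   rows = get_data_by_fields(fields, nb=2)
+#   common = get_common_fields_by_data(rows, exclude_fields=["vector"])
+#   # -> ["uid", "name"] in some order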