mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
[test]Update restful api test (#25581)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
7fec0d61cc
commit
eade5f9b7f
@ -1,17 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Alias(RestClient):
|
|
||||||
|
|
||||||
def drop_alias():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def alter_alias():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_alias():
|
|
||||||
pass
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Collection(RestClient):
|
|
||||||
|
|
||||||
@DELETE("collection")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def drop_collection(self, payload):
|
|
||||||
"""Drop a collection"""
|
|
||||||
|
|
||||||
@GET("collection")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def describe_collection(self, payload):
|
|
||||||
"""Describe a collection"""
|
|
||||||
|
|
||||||
@POST("collection")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def create_collection(self, payload):
|
|
||||||
"""Create a collection"""
|
|
||||||
|
|
||||||
@GET("collection/existence")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def has_collection(self, payload):
|
|
||||||
"""Check if a collection exists"""
|
|
||||||
|
|
||||||
@DELETE("collection/load")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def release_collection(self, payload):
|
|
||||||
"""Release a collection"""
|
|
||||||
|
|
||||||
@POST("collection/load")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def load_collection(self, payload):
|
|
||||||
"""Load a collection"""
|
|
||||||
|
|
||||||
@GET("collection/statistics")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_collection_statistics(self, payload):
|
|
||||||
"""Get collection statistics"""
|
|
||||||
|
|
||||||
@GET("collections")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def show_collections(self, payload):
|
|
||||||
"""Show collections"""
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
client = Collection("http://localhost:19121/api/v1")
|
|
||||||
print(client)
|
|
@ -1,19 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Credential(RestClient):
|
|
||||||
|
|
||||||
def delete_credential():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update_credential():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_credential():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def list_credentials():
|
|
||||||
pass
|
|
@ -1,63 +0,0 @@
|
|||||||
import json
|
|
||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(RestClient):
|
|
||||||
|
|
||||||
@POST("distance")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def calc_distance(self, payload):
|
|
||||||
""" Calculate distance between two points """
|
|
||||||
|
|
||||||
@DELETE("entities")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def delete(self, payload):
|
|
||||||
"""delete entities"""
|
|
||||||
|
|
||||||
@POST("entities")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def insert(self, payload):
|
|
||||||
"""insert entities"""
|
|
||||||
|
|
||||||
@POST("persist")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def flush(self, payload):
|
|
||||||
"""flush entities"""
|
|
||||||
|
|
||||||
@POST("persist/segment-info")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_persistent_segment_info(self, payload):
|
|
||||||
"""get persistent segment info"""
|
|
||||||
|
|
||||||
@POST("persist/state")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_flush_state(self, payload):
|
|
||||||
"""get flush state"""
|
|
||||||
|
|
||||||
@POST("query")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def query(self, payload):
|
|
||||||
"""query entities"""
|
|
||||||
|
|
||||||
@POST("query-segment-info")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_query_segment_info(self, payload):
|
|
||||||
"""get query segment info"""
|
|
||||||
|
|
||||||
@POST("search")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def search(self, payload):
|
|
||||||
"""search entities"""
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
class Import(RestClient):
|
|
||||||
|
|
||||||
def list_import_tasks():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def exec_import():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_import_state():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
|||||||
import json
|
|
||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Index(RestClient):
|
|
||||||
|
|
||||||
@DELETE("/index")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def drop_index(self, payload):
|
|
||||||
"""Drop an index"""
|
|
||||||
|
|
||||||
@GET("/index")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def describe_index(self, payload):
|
|
||||||
"""Describe an index"""
|
|
||||||
|
|
||||||
@POST("index")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def create_index(self, payload):
|
|
||||||
"""create index"""
|
|
||||||
|
|
||||||
@GET("index/progress")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_index_build_progress(self, payload):
|
|
||||||
"""get index build progress"""
|
|
||||||
|
|
||||||
@GET("index/state")
|
|
||||||
@body("payload", lambda p: json.dumps(p))
|
|
||||||
@on(200, lambda r: r.json())
|
|
||||||
def get_index_state(self, payload):
|
|
||||||
"""get index state"""
|
|
@ -1,11 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Metrics(RestClient):
|
|
||||||
|
|
||||||
def get_metrics():
|
|
||||||
pass
|
|
||||||
|
|
257
tests/restful_client/api/milvus.py
Normal file
257
tests/restful_client/api/milvus.py
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from utils.util_log import test_log as logger
|
||||||
|
|
||||||
|
|
||||||
|
def logger_request_response(response, url, tt, headers, data, str_data, str_response, method):
|
||||||
|
if len(data) > 2000:
|
||||||
|
data = data[:1000] + "..." + data[-1000:]
|
||||||
|
try:
|
||||||
|
if response.status_code == 200:
|
||||||
|
if ('code' in response.json() and response.json()["code"] == 200) or ('Code' in response.json() and response.json()["Code"] == 0):
|
||||||
|
logger.debug(
|
||||||
|
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {str_data}, response: {str_response}")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
logger.error(
|
||||||
|
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||||
|
|
||||||
|
|
||||||
|
class Requests:
|
||||||
|
def __init__(self, url=None, api_key=None):
|
||||||
|
self.url = url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'RequestId': str(uuid.uuid1())
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_headers(self):
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'RequestId': str(uuid.uuid1())
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def post(self, url, headers=None, data=None):
|
||||||
|
headers = headers if headers is not None else self.update_headers()
|
||||||
|
data = json.dumps(data)
|
||||||
|
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||||
|
t0 = time.time()
|
||||||
|
response = requests.post(url, headers=headers, data=data)
|
||||||
|
tt = time.time() - t0
|
||||||
|
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||||
|
logger_request_response(response, url, tt, headers, data, str_data, str_response, "post")
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get(self, url, headers=None, params=None, data=None):
|
||||||
|
headers = headers if headers is not None else self.update_headers()
|
||||||
|
data = json.dumps(data)
|
||||||
|
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||||
|
t0 = time.time()
|
||||||
|
if data is None or data == "null":
|
||||||
|
response = requests.get(url, headers=headers, params=params)
|
||||||
|
else:
|
||||||
|
response = requests.get(url, headers=headers, params=params, data=data)
|
||||||
|
tt = time.time() - t0
|
||||||
|
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||||
|
logger_request_response(response, url, tt, headers, data, str_data, str_response, "get")
|
||||||
|
return response
|
||||||
|
|
||||||
|
def put(self, url, headers=None, data=None):
|
||||||
|
headers = headers if headers is not None else self.update_headers()
|
||||||
|
data = json.dumps(data)
|
||||||
|
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||||
|
t0 = time.time()
|
||||||
|
response = requests.put(url, headers=headers, data=data)
|
||||||
|
tt = time.time() - t0
|
||||||
|
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||||
|
logger_request_response(response, url, tt, headers, data, str_data, str_response, "put")
|
||||||
|
return response
|
||||||
|
|
||||||
|
def delete(self, url, headers=None, data=None):
|
||||||
|
headers = headers if headers is not None else self.update_headers()
|
||||||
|
data = json.dumps(data)
|
||||||
|
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||||
|
t0 = time.time()
|
||||||
|
response = requests.delete(url, headers=headers, data=data)
|
||||||
|
tt = time.time() - t0
|
||||||
|
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||||
|
logger_request_response(response, url, tt, headers, data, str_data, str_response, "delete")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class VectorClient(Requests):
|
||||||
|
def __init__(self, url, api_key, protocol="http"):
|
||||||
|
super().__init__(url, api_key)
|
||||||
|
self.protocol = protocol
|
||||||
|
self.url = url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.db_name = None
|
||||||
|
self.headers = self.update_headers()
|
||||||
|
|
||||||
|
def update_headers(self):
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'RequestId': str(uuid.uuid1())
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def vector_search(self, payload, db_name="default", timeout=10):
|
||||||
|
time.sleep(1)
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/search'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if "data" in rsp and len(rsp["data"]) == 0:
|
||||||
|
t0 = time.time()
|
||||||
|
while time.time() - t0 < timeout:
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if len(rsp["data"]) > 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if "data" in rsp and len(rsp["data"]) == 0:
|
||||||
|
logger.info(f"after {timeout}s, still no data")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def vector_query(self, payload, db_name="default", timeout=10):
|
||||||
|
time.sleep(1)
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/query'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if "data" in rsp and len(rsp["data"]) == 0:
|
||||||
|
t0 = time.time()
|
||||||
|
while time.time() - t0 < timeout:
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if len(rsp["data"]) > 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
rsp = response.json()
|
||||||
|
if "data" in rsp and len(rsp["data"]) == 0:
|
||||||
|
logger.info(f"after {timeout}s, still no data")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def vector_get(self, payload, db_name="default"):
|
||||||
|
time.sleep(1)
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/get'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def vector_delete(self, payload, db_name="default"):
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/delete'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def vector_insert(self, payload, db_name="default"):
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/insert'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionClient(Requests):
|
||||||
|
|
||||||
|
def __init__(self, url, api_key, protocol="http"):
|
||||||
|
super().__init__(url, api_key)
|
||||||
|
self.protocol = protocol
|
||||||
|
self.url = url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.db_name = None
|
||||||
|
self.headers = self.update_headers()
|
||||||
|
|
||||||
|
def update_headers(self):
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'RequestId': str(uuid.uuid1())
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def collection_list(self, db_name="default"):
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/collections'
|
||||||
|
params = {}
|
||||||
|
if self.db_name is not None:
|
||||||
|
params = {
|
||||||
|
"dbName": self.db_name
|
||||||
|
}
|
||||||
|
if db_name != "default":
|
||||||
|
params = {
|
||||||
|
"dbName": db_name
|
||||||
|
}
|
||||||
|
response = self.get(url, headers=self.update_headers(), params=params)
|
||||||
|
res = response.json()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def collection_create(self, payload, db_name="default"):
|
||||||
|
time.sleep(1) # wait for collection created and in case of rate limit
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/collections/create'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def collection_describe(self, collection_name, db_name="default"):
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/collections/describe'
|
||||||
|
params = {"collectionName": collection_name}
|
||||||
|
if self.db_name is not None:
|
||||||
|
params = {
|
||||||
|
"collectionName": collection_name,
|
||||||
|
"dbName": self.db_name
|
||||||
|
}
|
||||||
|
if db_name != "default":
|
||||||
|
params = {
|
||||||
|
"collectionName": collection_name,
|
||||||
|
"dbName": db_name
|
||||||
|
}
|
||||||
|
response = self.get(url, headers=self.update_headers(), params=params)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def collection_drop(self, payload, db_name="default"):
|
||||||
|
time.sleep(1) # wait for collection drop and in case of rate limit
|
||||||
|
url = f'{self.protocol}://{self.url}/vector/collections/drop'
|
||||||
|
if self.db_name is not None:
|
||||||
|
payload["dbName"] = self.db_name
|
||||||
|
if db_name != "default":
|
||||||
|
payload["dbName"] = db_name
|
||||||
|
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||||
|
return response.json()
|
@ -1,21 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
class Ops(RestClient):
|
|
||||||
|
|
||||||
def manual_compaction():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_compaction_plans():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_compaction_state():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_balance():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_replicas():
|
|
||||||
pass
|
|
@ -1,27 +0,0 @@
|
|||||||
from decorest import GET, POST, DELETE
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from decorest import accept, body, content, endpoint, form
|
|
||||||
from decorest import header, multipart, on, query, stream, timeout
|
|
||||||
|
|
||||||
|
|
||||||
class Partition(RestClient):
|
|
||||||
def drop_partition():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_partition():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def has_partition():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_partition_statistics():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def show_partitions():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def release_partition():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_partition():
|
|
||||||
pass
|
|
@ -1,43 +0,0 @@
|
|||||||
from datetime import date, datetime
|
|
||||||
from typing import List, Union, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, UUID4, conlist
|
|
||||||
|
|
||||||
from pydantic_factories import ModelFactory
|
|
||||||
|
|
||||||
|
|
||||||
class Person(BaseModel):
|
|
||||||
def __init__(self, length):
|
|
||||||
super().__init__()
|
|
||||||
self.len = length
|
|
||||||
id: UUID4
|
|
||||||
name: str
|
|
||||||
hobbies: List[str]
|
|
||||||
age: Union[float, int]
|
|
||||||
birthday: Union[datetime, date]
|
|
||||||
|
|
||||||
|
|
||||||
class Pet(BaseModel):
|
|
||||||
name: str
|
|
||||||
age: int
|
|
||||||
|
|
||||||
|
|
||||||
class PetFactory(BaseModel):
|
|
||||||
name: str
|
|
||||||
pet: Pet
|
|
||||||
age: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
sample = {
|
|
||||||
"name": "John",
|
|
||||||
"pet": {
|
|
||||||
"name": "Fido",
|
|
||||||
"age": 3
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
result = PetFactory(**sample)
|
|
||||||
|
|
||||||
print(result)
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
|||||||
from time import sleep
|
|
||||||
|
|
||||||
from decorest import HttpStatus, RestClient
|
|
||||||
from models.schema import CollectionSchema
|
|
||||||
from base.collection_service import CollectionService
|
|
||||||
from base.index_service import IndexService
|
|
||||||
from base.entity_service import EntityService
|
|
||||||
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
from common import common_func as cf
|
|
||||||
from common import common_type as ct
|
|
||||||
|
|
||||||
class Base:
|
|
||||||
"""init base class"""
|
|
||||||
|
|
||||||
endpoint = None
|
|
||||||
collection_service = None
|
|
||||||
index_service = None
|
|
||||||
entity_service = None
|
|
||||||
collection_name = None
|
|
||||||
collection_object_list = []
|
|
||||||
|
|
||||||
def setup_class(self):
|
|
||||||
log.info("setup class")
|
|
||||||
|
|
||||||
def teardown_class(self):
|
|
||||||
log.info("teardown class")
|
|
||||||
|
|
||||||
def setup_method(self, method):
|
|
||||||
log.info(("*" * 35) + " setup " + ("*" * 35))
|
|
||||||
log.info("[setup_method] Start setup test case %s." % method.__name__)
|
|
||||||
host = cf.param_info.param_host
|
|
||||||
port = cf.param_info.param_port
|
|
||||||
self.endpoint = "http://" + host + ":" + str(port) + "/api/v1"
|
|
||||||
self.collection_service = CollectionService(self.endpoint)
|
|
||||||
self.index_service = IndexService(self.endpoint)
|
|
||||||
self.entity_service = EntityService(self.endpoint)
|
|
||||||
|
|
||||||
def teardown_method(self, method):
|
|
||||||
res = self.collection_service.has_collection(collection_name=self.collection_name)
|
|
||||||
log.info(f"collection {self.collection_name} exists: {res}")
|
|
||||||
if res["value"] is True:
|
|
||||||
res = self.collection_service.drop_collection(self.collection_name)
|
|
||||||
log.info(f"drop collection {self.collection_name} res: {res}")
|
|
||||||
res = self.collection_service.show_collections()
|
|
||||||
all_collections = res["collection_names"]
|
|
||||||
union_collections = set(all_collections) & set(self.collection_object_list)
|
|
||||||
for collection in union_collections:
|
|
||||||
res = self.collection_service.drop_collection(collection)
|
|
||||||
log.info(f"drop collection {collection} res: {res}")
|
|
||||||
log.info("[teardown_method] Start teardown test case %s." % method.__name__)
|
|
||||||
log.info(("*" * 35) + " teardown " + ("*" * 35))
|
|
||||||
|
|
||||||
|
|
||||||
class TestBase(Base):
|
|
||||||
"""init test base class"""
|
|
||||||
|
|
||||||
def init_collection(self, name=None, schema=None):
|
|
||||||
collection_name = cf.gen_unique_str("test") if name is None else name
|
|
||||||
self.collection_name = collection_name
|
|
||||||
self.collection_object_list.append(collection_name)
|
|
||||||
if schema is None:
|
|
||||||
schema = cf.gen_default_schema(collection_name=collection_name)
|
|
||||||
# create collection
|
|
||||||
res = self.collection_service.create_collection(collection_name=collection_name, schema=schema)
|
|
||||||
|
|
||||||
log.info(f"create collection name: {collection_name} with schema: {schema}")
|
|
||||||
return collection_name, schema
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
|||||||
from api.collection import Collection
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
from models import milvus
|
|
||||||
|
|
||||||
|
|
||||||
TIMEOUT = 30
|
|
||||||
|
|
||||||
|
|
||||||
class CollectionService:
|
|
||||||
|
|
||||||
def __init__(self, endpoint=None, timeout=None):
|
|
||||||
if timeout is None:
|
|
||||||
timeout = TIMEOUT
|
|
||||||
if endpoint is None:
|
|
||||||
endpoint = "http://localhost:9091/api/v1"
|
|
||||||
self._collection = Collection(endpoint=endpoint)
|
|
||||||
|
|
||||||
def create_collection(self, collection_name, consistency_level=1, schema=None, shards_num=2):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"consistency_level": consistency_level,
|
|
||||||
"schema": schema,
|
|
||||||
"shards_num": shards_num
|
|
||||||
}
|
|
||||||
log.info(f"payload: {payload}")
|
|
||||||
# payload = milvus.CreateCollectionRequest(collection_name=collection_name,
|
|
||||||
# consistency_level=consistency_level,
|
|
||||||
# schema=schema,
|
|
||||||
# shards_num=shards_num)
|
|
||||||
# payload = payload.dict()
|
|
||||||
rsp = self._collection.create_collection(payload)
|
|
||||||
return rsp
|
|
||||||
|
|
||||||
def has_collection(self, collection_name=None, time_stamp=0):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"time_stamp": time_stamp
|
|
||||||
}
|
|
||||||
# payload = milvus.HasCollectionRequest(collection_name=collection_name, time_stamp=time_stamp)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.has_collection(payload)
|
|
||||||
|
|
||||||
def drop_collection(self, collection_name):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name
|
|
||||||
}
|
|
||||||
# payload = milvus.DropCollectionRequest(collection_name=collection_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.drop_collection(payload)
|
|
||||||
|
|
||||||
def describe_collection(self, collection_name, collection_id=None, time_stamp=0):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"collection_id": collection_id,
|
|
||||||
"time_stamp": time_stamp
|
|
||||||
}
|
|
||||||
# payload = milvus.DescribeCollectionRequest(collection_name=collection_name,
|
|
||||||
# collectionID=collection_id,
|
|
||||||
# time_stamp=time_stamp)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.describe_collection(payload)
|
|
||||||
|
|
||||||
def load_collection(self, collection_name, replica_number=1):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"replica_number": replica_number
|
|
||||||
}
|
|
||||||
# payload = milvus.LoadCollectionRequest(collection_name=collection_name, replica_number=replica_number)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.load_collection(payload)
|
|
||||||
|
|
||||||
def release_collection(self, collection_name):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name
|
|
||||||
}
|
|
||||||
# payload = milvus.ReleaseCollectionRequest(collection_name=collection_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.release_collection(payload)
|
|
||||||
|
|
||||||
def get_collection_statistics(self, collection_name):
|
|
||||||
payload = {
|
|
||||||
"collection_name": collection_name
|
|
||||||
}
|
|
||||||
# payload = milvus.GetCollectionStatisticsRequest(collection_name=collection_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.get_collection_statistics(payload)
|
|
||||||
|
|
||||||
def show_collections(self, collection_names=None, type=0):
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"collection_names": collection_names,
|
|
||||||
"type": type
|
|
||||||
}
|
|
||||||
# payload = milvus.ShowCollectionsRequest(collection_names=collection_names, type=type)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._collection.show_collections(payload)
|
|
@ -1,182 +0,0 @@
|
|||||||
from api.entity import Entity
|
|
||||||
from common import common_type as ct
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
from models import common, schema, milvus, server
|
|
||||||
|
|
||||||
TIMEOUT = 30
|
|
||||||
|
|
||||||
|
|
||||||
class EntityService:
|
|
||||||
|
|
||||||
def __init__(self, endpoint=None, timeout=None):
|
|
||||||
if timeout is None:
|
|
||||||
timeout = TIMEOUT
|
|
||||||
if endpoint is None:
|
|
||||||
endpoint = "http://localhost:9091/api/v1"
|
|
||||||
self._entity = Entity(endpoint=endpoint)
|
|
||||||
|
|
||||||
def calc_distance(self, base=None, op_left=None, op_right=None, params=None):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"op_left": op_left,
|
|
||||||
"op_right": op_right,
|
|
||||||
"params": params
|
|
||||||
}
|
|
||||||
# payload = milvus.CalcDistanceRequest(base=base, op_left=op_left, op_right=op_right, params=params)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._entity.calc_distance(payload)
|
|
||||||
|
|
||||||
def delete(self, base=None, collection_name=None, db_name=None, expr=None, hash_keys=None, partition_name=None):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name,
|
|
||||||
"expr": expr,
|
|
||||||
"hash_keys": hash_keys,
|
|
||||||
"partition_name": partition_name
|
|
||||||
}
|
|
||||||
# payload = server.DeleteRequest(base=base,
|
|
||||||
# collection_name=collection_name,
|
|
||||||
# db_name=db_name,
|
|
||||||
# expr=expr,
|
|
||||||
# hash_keys=hash_keys,
|
|
||||||
# partition_name=partition_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._entity.delete(payload)
|
|
||||||
|
|
||||||
def insert(self, base=None, collection_name=None, db_name=None, fields_data=None, hash_keys=None, num_rows=None,
|
|
||||||
partition_name=None, check_task=True):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name,
|
|
||||||
"fields_data": fields_data,
|
|
||||||
"hash_keys": hash_keys,
|
|
||||||
"num_rows": num_rows,
|
|
||||||
"partition_name": partition_name
|
|
||||||
}
|
|
||||||
# payload = milvus.InsertRequest(base=base,
|
|
||||||
# collection_name=collection_name,
|
|
||||||
# db_name=db_name,
|
|
||||||
# fields_data=fields_data,
|
|
||||||
# hash_keys=hash_keys,
|
|
||||||
# num_rows=num_rows,
|
|
||||||
# partition_name=partition_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
rsp = self._entity.insert(payload)
|
|
||||||
if check_task:
|
|
||||||
assert rsp["status"] == {}
|
|
||||||
assert rsp["insert_cnt"] == num_rows
|
|
||||||
return rsp
|
|
||||||
|
|
||||||
def flush(self, base=None, collection_names=None, db_name=None, check_task=True):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_names": collection_names,
|
|
||||||
"db_name": db_name
|
|
||||||
}
|
|
||||||
# payload = server.FlushRequest(base=base,
|
|
||||||
# collection_names=collection_names,
|
|
||||||
# db_name=db_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
rsp = self._entity.flush(payload)
|
|
||||||
if check_task:
|
|
||||||
assert rsp["status"] == {}
|
|
||||||
|
|
||||||
def get_persistent_segment_info(self, base=None, collection_name=None, db_name=None):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name
|
|
||||||
}
|
|
||||||
# payload = server.GetPersistentSegmentInfoRequest(base=base,
|
|
||||||
# collection_name=collection_name,
|
|
||||||
# db_name=db_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._entity.get_persistent_segment_info(payload)
|
|
||||||
|
|
||||||
def get_flush_state(self, segment_ids=None):
|
|
||||||
payload = {
|
|
||||||
"segment_ids": segment_ids
|
|
||||||
}
|
|
||||||
# payload = server.GetFlushStateRequest(segment_ids=segment_ids)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._entity.get_flush_state(payload)
|
|
||||||
|
|
||||||
def query(self, base=None, collection_name=None, db_name=None, expr=None,
|
|
||||||
guarantee_timestamp=None, output_fields=None, partition_names=None, travel_timestamp=None,
|
|
||||||
check_task=True):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name,
|
|
||||||
"expr": expr,
|
|
||||||
"guarantee_timestamp": guarantee_timestamp,
|
|
||||||
"output_fields": output_fields,
|
|
||||||
"partition_names": partition_names,
|
|
||||||
"travel_timestamp": travel_timestamp
|
|
||||||
}
|
|
||||||
#
|
|
||||||
# payload = server.QueryRequest(base=base, collection_name=collection_name, db_name=db_name, expr=expr,
|
|
||||||
# guarantee_timestamp=guarantee_timestamp, output_fields=output_fields,
|
|
||||||
# partition_names=partition_names, travel_timestamp=travel_timestamp)
|
|
||||||
# payload = payload.dict()
|
|
||||||
rsp = self._entity.query(payload)
|
|
||||||
if check_task:
|
|
||||||
fields_data = rsp["fields_data"]
|
|
||||||
for field_data in fields_data:
|
|
||||||
if field_data["field_name"] in expr:
|
|
||||||
data = field_data["Field"]["Scalars"]["Data"]["LongData"]["data"]
|
|
||||||
for d in data:
|
|
||||||
s = expr.replace(field_data["field_name"], str(d))
|
|
||||||
assert eval(s) is True
|
|
||||||
return rsp
|
|
||||||
|
|
||||||
def get_query_segment_info(self, base=None, collection_name=None, db_name=None):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name
|
|
||||||
}
|
|
||||||
# payload = server.GetQuerySegmentInfoRequest(base=base,
|
|
||||||
# collection_name=collection_name,
|
|
||||||
# db_name=db_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._entity.get_query_segment_info(payload)
|
|
||||||
|
|
||||||
def search(self, base=None, collection_name=None, vectors=None, db_name=None, dsl=None,
|
|
||||||
output_fields=None, dsl_type=1,
|
|
||||||
guarantee_timestamp=None, partition_names=None, placeholder_group=None,
|
|
||||||
search_params=None, travel_timestamp=None, check_task=True):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"output_fields": output_fields,
|
|
||||||
"vectors": vectors,
|
|
||||||
"db_name": db_name,
|
|
||||||
"dsl": dsl,
|
|
||||||
"dsl_type": dsl_type,
|
|
||||||
"guarantee_timestamp": guarantee_timestamp,
|
|
||||||
"partition_names": partition_names,
|
|
||||||
"placeholder_group": placeholder_group,
|
|
||||||
"search_params": search_params,
|
|
||||||
"travel_timestamp": travel_timestamp
|
|
||||||
}
|
|
||||||
# payload = server.SearchRequest(base=base, collection_name=collection_name, db_name=db_name, dsl=dsl,
|
|
||||||
# dsl_type=dsl_type, guarantee_timestamp=guarantee_timestamp,
|
|
||||||
# partition_names=partition_names, placeholder_group=placeholder_group,
|
|
||||||
# search_params=search_params, travel_timestamp=travel_timestamp)
|
|
||||||
# payload = payload.dict()
|
|
||||||
rsp = self._entity.search(payload)
|
|
||||||
if check_task:
|
|
||||||
assert rsp["status"] == {}
|
|
||||||
assert rsp["results"]["num_queries"] == len(vectors)
|
|
||||||
assert len(rsp["results"]["ids"]["IdField"]["IntId"]["data"]) == sum(rsp["results"]["topks"])
|
|
||||||
return rsp
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
41
tests/restful_client/base/error_code_message.py
Normal file
41
tests/restful_client/base/error_code_message.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class BaseError(Enum):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorInsertError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorGetError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorQueryError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDeleteError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionListError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionCreateError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionDropError(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionDescribeError(BaseError):
|
||||||
|
pass
|
@ -1,57 +0,0 @@
|
|||||||
from api.index import Index
|
|
||||||
from models import common, schema, milvus, server
|
|
||||||
|
|
||||||
TIMEOUT = 30
|
|
||||||
|
|
||||||
|
|
||||||
class IndexService:
|
|
||||||
|
|
||||||
def __init__(self, endpoint=None, timeout=None):
|
|
||||||
if timeout is None:
|
|
||||||
timeout = TIMEOUT
|
|
||||||
if endpoint is None:
|
|
||||||
endpoint = "http://localhost:9091/api/v1"
|
|
||||||
self._index = Index(endpoint=endpoint)
|
|
||||||
|
|
||||||
def drop_index(self, base, collection_name, db_name, field_name, index_name):
|
|
||||||
payload = server.DropIndexRequest(base=base, collection_name=collection_name,
|
|
||||||
db_name=db_name, field_name=field_name, index_name=index_name)
|
|
||||||
payload = payload.dict()
|
|
||||||
return self._index.drop_index(payload)
|
|
||||||
|
|
||||||
def describe_index(self, base, collection_name, db_name, field_name, index_name):
|
|
||||||
payload = server.DescribeIndexRequest(base=base, collection_name=collection_name,
|
|
||||||
db_name=db_name, field_name=field_name, index_name=index_name)
|
|
||||||
payload = payload.dict()
|
|
||||||
return self._index.describe_index(payload)
|
|
||||||
|
|
||||||
def create_index(self, base=None, collection_name=None, db_name=None, extra_params=None,
|
|
||||||
field_name=None, index_name=None):
|
|
||||||
payload = {
|
|
||||||
"base": base,
|
|
||||||
"collection_name": collection_name,
|
|
||||||
"db_name": db_name,
|
|
||||||
"extra_params": extra_params,
|
|
||||||
"field_name": field_name,
|
|
||||||
"index_name": index_name
|
|
||||||
}
|
|
||||||
# payload = server.CreateIndexRequest(base=base, collection_name=collection_name, db_name=db_name,
|
|
||||||
# extra_params=extra_params, field_name=field_name, index_name=index_name)
|
|
||||||
# payload = payload.dict()
|
|
||||||
return self._index.create_index(payload)
|
|
||||||
|
|
||||||
def get_index_build_progress(self, base, collection_name, db_name, field_name, index_name):
|
|
||||||
payload = server.GetIndexBuildProgressRequest(base=base, collection_name=collection_name,
|
|
||||||
db_name=db_name, field_name=field_name, index_name=index_name)
|
|
||||||
payload = payload.dict()
|
|
||||||
return self._index.get_index_build_progress(payload)
|
|
||||||
|
|
||||||
def get_index_state(self, base, collection_name, db_name, field_name, index_name):
|
|
||||||
payload = server.GetIndexStateRequest(base=base, collection_name=collection_name,
|
|
||||||
db_name=db_name, field_name=field_name, index_name=index_name)
|
|
||||||
payload = payload.dict()
|
|
||||||
return self._index.get_index_state(payload)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
114
tests/restful_client/base/testbase.py
Normal file
114
tests/restful_client/base/testbase.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
from pymilvus import connections, db
|
||||||
|
from utils.util_log import test_log as logger
|
||||||
|
from api.milvus import VectorClient, CollectionClient
|
||||||
|
from utils.utils import get_data_by_payload
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Base:
|
||||||
|
name = None
|
||||||
|
host = None
|
||||||
|
port = None
|
||||||
|
url = None
|
||||||
|
api_key = None
|
||||||
|
username = None
|
||||||
|
password = None
|
||||||
|
invalid_api_key = None
|
||||||
|
vector_client = None
|
||||||
|
collection_client = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBase(Base):
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
self.collection_client.api_key = self.api_key
|
||||||
|
all_collections = self.collection_client.collection_list()['data']
|
||||||
|
if self.name in all_collections:
|
||||||
|
logger.info(f"collection {self.name} exist, drop it")
|
||||||
|
payload = {
|
||||||
|
"collectionName": self.name,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
rsp = self.collection_client.collection_drop(payload)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def init_client(self, host, port, username, password):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.url = f"{host}:{port}/v1"
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
self.api_key = f"{self.username}:{self.password}"
|
||||||
|
self.invalid_api_key = "invalid_token"
|
||||||
|
self.vector_client = VectorClient(self.url, self.api_key)
|
||||||
|
self.collection_client = CollectionClient(self.url, self.api_key)
|
||||||
|
|
||||||
|
def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100, batch_size=1000):
|
||||||
|
# create collection
|
||||||
|
schema_payload = {
|
||||||
|
"collectionName": collection_name,
|
||||||
|
"dimension": dim,
|
||||||
|
"metricType": metric_type,
|
||||||
|
"description": "test collection",
|
||||||
|
"primaryField": pk_field,
|
||||||
|
"vectorField": "vector",
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(schema_payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
self.wait_collection_load_completed(collection_name)
|
||||||
|
batch_size = batch_size
|
||||||
|
batch = nb // batch_size
|
||||||
|
# in case of nb < batch_size
|
||||||
|
if batch == 0:
|
||||||
|
batch = 1
|
||||||
|
batch_size = nb
|
||||||
|
data = []
|
||||||
|
for i in range(batch):
|
||||||
|
nb = batch_size
|
||||||
|
data = get_data_by_payload(schema_payload, nb)
|
||||||
|
payload = {
|
||||||
|
"collectionName": collection_name,
|
||||||
|
"data": data
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
rsp = self.vector_client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
return schema_payload, data
|
||||||
|
|
||||||
|
def wait_collection_load_completed(self, name):
|
||||||
|
t0 = time.time()
|
||||||
|
timeout = 60
|
||||||
|
while True and time.time() - t0 < timeout:
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
if "data" in rsp and "load" in rsp["data"] and rsp["data"]["load"] == "LoadStateLoaded":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
def create_database(self, db_name="default"):
|
||||||
|
connections.connect(host=self.host, port=self.port)
|
||||||
|
all_db = db.list_database()
|
||||||
|
logger.info(f"all database: {all_db}")
|
||||||
|
if db_name not in all_db:
|
||||||
|
logger.info(f"create database: {db_name}")
|
||||||
|
try:
|
||||||
|
db.create_database(db_name=db_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def update_database(self, db_name="default"):
|
||||||
|
self.create_database(db_name=db_name)
|
||||||
|
self.collection_client.db_name = db_name
|
||||||
|
self.vector_client.db_name = db_name
|
||||||
|
|
@ -1,27 +0,0 @@
|
|||||||
|
|
||||||
class CheckTasks:
|
|
||||||
""" The name of the method used to check the result """
|
|
||||||
check_nothing = "check_nothing"
|
|
||||||
err_res = "error_response"
|
|
||||||
ccr = "check_connection_result"
|
|
||||||
check_collection_property = "check_collection_property"
|
|
||||||
check_partition_property = "check_partition_property"
|
|
||||||
check_search_results = "check_search_results"
|
|
||||||
check_query_results = "check_query_results"
|
|
||||||
check_query_empty = "check_query_empty" # verify that query result is empty
|
|
||||||
check_query_not_empty = "check_query_not_empty"
|
|
||||||
check_distance = "check_distance"
|
|
||||||
check_delete_compact = "check_delete_compact"
|
|
||||||
check_merge_compact = "check_merge_compact"
|
|
||||||
check_role_property = "check_role_property"
|
|
||||||
check_permission_deny = "check_permission_deny"
|
|
||||||
check_value_equal = "check_value_equal"
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseChecker:
|
|
||||||
|
|
||||||
def __init__(self, check_task, check_items):
|
|
||||||
self.check_task = check_task
|
|
||||||
self.check_items = check_items
|
|
||||||
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
|||||||
from utils.util_log import test_log as log
|
|
||||||
|
|
||||||
|
|
||||||
def ip_check(ip):
|
|
||||||
if ip == "localhost":
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not isinstance(ip, str):
|
|
||||||
log.error("[IP_CHECK] IP(%s) is not a string." % ip)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def number_check(num):
|
|
||||||
if str(num).isdigit():
|
|
||||||
return True
|
|
||||||
|
|
||||||
else:
|
|
||||||
log.error("[NUMBER_CHECK] Number(%s) is not a numbers." % num)
|
|
||||||
return False
|
|
@ -1,292 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import string
|
|
||||||
import numpy as np
|
|
||||||
from enum import Enum
|
|
||||||
from common import common_type as ct
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
|
|
||||||
|
|
||||||
class ParamInfo:
|
|
||||||
def __init__(self):
|
|
||||||
self.param_host = ""
|
|
||||||
self.param_port = ""
|
|
||||||
|
|
||||||
def prepare_param_info(self, host, http_port):
|
|
||||||
self.param_host = host
|
|
||||||
self.param_port = http_port
|
|
||||||
|
|
||||||
|
|
||||||
param_info = ParamInfo()
|
|
||||||
|
|
||||||
|
|
||||||
class DataType(Enum):
|
|
||||||
Bool: 1
|
|
||||||
Int8: 2
|
|
||||||
Int16: 3
|
|
||||||
Int32: 4
|
|
||||||
Int64: 5
|
|
||||||
Float: 10
|
|
||||||
Double: 11
|
|
||||||
String: 20
|
|
||||||
VarChar: 21
|
|
||||||
BinaryVector: 100
|
|
||||||
FloatVector: 101
|
|
||||||
|
|
||||||
|
|
||||||
def gen_unique_str(str_value=None):
|
|
||||||
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
|
|
||||||
return "test_" + prefix if str_value is None else str_value + "_" + prefix
|
|
||||||
|
|
||||||
|
|
||||||
def gen_field(name=ct.default_bool_field_name, description=ct.default_desc, type_params=None, index_params=None,
|
|
||||||
data_type="Int64", is_primary_key=False, auto_id=False, dim=128, max_length=256):
|
|
||||||
data_type_map = {
|
|
||||||
"Bool": 1,
|
|
||||||
"Int8": 2,
|
|
||||||
"Int16": 3,
|
|
||||||
"Int32": 4,
|
|
||||||
"Int64": 5,
|
|
||||||
"Float": 10,
|
|
||||||
"Double": 11,
|
|
||||||
"String": 20,
|
|
||||||
"VarChar": 21,
|
|
||||||
"BinaryVector": 100,
|
|
||||||
"FloatVector": 101,
|
|
||||||
}
|
|
||||||
if data_type == "Int64":
|
|
||||||
is_primary_key = True
|
|
||||||
auto_id = True
|
|
||||||
if type_params is None:
|
|
||||||
type_params = []
|
|
||||||
if index_params is None:
|
|
||||||
index_params = []
|
|
||||||
if data_type in ["FloatVector", "BinaryVector"]:
|
|
||||||
type_params = [{"key": "dim", "value": str(dim)}]
|
|
||||||
if data_type in ["String", "VarChar"]:
|
|
||||||
type_params = [{"key": "max_length", "value": str(dim)}]
|
|
||||||
return {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"data_type": data_type_map.get(data_type, 0),
|
|
||||||
"type_params": type_params,
|
|
||||||
"index_params": index_params,
|
|
||||||
"is_primary_key": is_primary_key,
|
|
||||||
"auto_id": auto_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def gen_schema(name, fields, description=ct.default_desc, auto_id=False):
|
|
||||||
return {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"auto_id": auto_id,
|
|
||||||
"fields": fields,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def gen_default_schema(data_types=None, dim=ct.default_dim, collection_name=None):
|
|
||||||
if data_types is None:
|
|
||||||
data_types = ["Int64", "Float", "VarChar", "FloatVector"]
|
|
||||||
fields = []
|
|
||||||
for data_type in data_types:
|
|
||||||
if data_type in ["FloatVector", "BinaryVector"]:
|
|
||||||
fields.append(gen_field(name=data_type, data_type=data_type, type_params=[{"key": "dim", "value": dim}]))
|
|
||||||
else:
|
|
||||||
fields.append(gen_field(name=data_type, data_type=data_type))
|
|
||||||
return {
|
|
||||||
"autoID": True,
|
|
||||||
"fields": fields,
|
|
||||||
"description": ct.default_desc,
|
|
||||||
"name": collection_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def gen_fields_data(schema=None, nb=ct.default_nb,):
|
|
||||||
if schema is None:
|
|
||||||
schema = gen_default_schema()
|
|
||||||
fields = schema["fields"]
|
|
||||||
fields_data = []
|
|
||||||
for field in fields:
|
|
||||||
if field["data_type"] == 1:
|
|
||||||
fields_data.append([random.choice([True, False]) for i in range(nb)])
|
|
||||||
elif field["data_type"] == 2:
|
|
||||||
fields_data.append([i for i in range(nb)])
|
|
||||||
elif field["data_type"] == 3:
|
|
||||||
fields_data.append([i for i in range(nb)])
|
|
||||||
elif field["data_type"] == 4:
|
|
||||||
fields_data.append([i for i in range(nb)])
|
|
||||||
elif field["data_type"] == 5:
|
|
||||||
fields_data.append([i for i in range(nb)])
|
|
||||||
elif field["data_type"] == 10:
|
|
||||||
fields_data.append([np.float64(i) for i in range(nb)]) # json not support float32
|
|
||||||
elif field["data_type"] == 11:
|
|
||||||
fields_data.append([np.float64(i) for i in range(nb)])
|
|
||||||
elif field["data_type"] == 20:
|
|
||||||
fields_data.append([gen_unique_str((str(i))) for i in range(nb)])
|
|
||||||
elif field["data_type"] == 21:
|
|
||||||
fields_data.append([gen_unique_str(str(i)) for i in range(nb)])
|
|
||||||
elif field["data_type"] == 100:
|
|
||||||
dim = ct.default_dim
|
|
||||||
for k, v in field["type_params"]:
|
|
||||||
if k == "dim":
|
|
||||||
dim = int(v)
|
|
||||||
break
|
|
||||||
fields_data.append(gen_binary_vectors(nb, dim))
|
|
||||||
elif field["data_type"] == 101:
|
|
||||||
dim = ct.default_dim
|
|
||||||
for k, v in field["type_params"]:
|
|
||||||
if k == "dim":
|
|
||||||
dim = int(v)
|
|
||||||
break
|
|
||||||
fields_data.append(gen_float_vectors(nb, dim))
|
|
||||||
else:
|
|
||||||
log.error("Unknown data type.")
|
|
||||||
fields_data_body = []
|
|
||||||
for i, field in enumerate(fields):
|
|
||||||
fields_data_body.append({
|
|
||||||
"field_name": field["name"],
|
|
||||||
"type": field["data_type"],
|
|
||||||
"field": fields_data[i],
|
|
||||||
})
|
|
||||||
return fields_data_body
|
|
||||||
|
|
||||||
|
|
||||||
def get_vector_field(schema):
|
|
||||||
for field in schema["fields"]:
|
|
||||||
if field["data_type"] in [100, 101]:
|
|
||||||
return field["name"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_varchar_field(schema):
|
|
||||||
for field in schema["fields"]:
|
|
||||||
if field["data_type"] == 21:
|
|
||||||
return field["name"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def gen_vectors(nq=None, schema=None):
|
|
||||||
if nq is None:
|
|
||||||
nq = ct.default_nq
|
|
||||||
dim = ct.default_dim
|
|
||||||
data_type = 101
|
|
||||||
for field in schema["fields"]:
|
|
||||||
if field["data_type"] in [100, 101]:
|
|
||||||
dim = ct.default_dim
|
|
||||||
data_type = field["data_type"]
|
|
||||||
for k, v in field["type_params"]:
|
|
||||||
if k == "dim":
|
|
||||||
dim = int(v)
|
|
||||||
break
|
|
||||||
if data_type == 100:
|
|
||||||
return gen_binary_vectors(nq, dim)
|
|
||||||
if data_type == 101:
|
|
||||||
return gen_float_vectors(nq, dim)
|
|
||||||
|
|
||||||
|
|
||||||
def gen_float_vectors(nb, dim):
|
|
||||||
return [[np.float64(random.uniform(-1.0, 1.0)) for _ in range(dim)] for _ in range(nb)] # json not support float32
|
|
||||||
|
|
||||||
|
|
||||||
def gen_binary_vectors(nb, dim):
|
|
||||||
raw_vectors = []
|
|
||||||
binary_vectors = []
|
|
||||||
for _ in range(nb):
|
|
||||||
raw_vector = [random.randint(0, 1) for _ in range(dim)]
|
|
||||||
raw_vectors.append(raw_vector)
|
|
||||||
# packs a binary-valued array into bits in a unit8 array, and bytes array_of_ints
|
|
||||||
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
|
|
||||||
return binary_vectors
|
|
||||||
|
|
||||||
|
|
||||||
def gen_index_params(index_type=None):
|
|
||||||
if index_type is None:
|
|
||||||
index_params = ct.default_index_params
|
|
||||||
else:
|
|
||||||
index_params = ct.all_index_params_map[index_type]
|
|
||||||
extra_params = []
|
|
||||||
for k, v in index_params.items():
|
|
||||||
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
|
|
||||||
extra_params.append(item)
|
|
||||||
return extra_params
|
|
||||||
|
|
||||||
def gen_search_param_by_index_type(index_type, metric_type="L2"):
|
|
||||||
search_params = []
|
|
||||||
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
|
|
||||||
for nprobe in [10]:
|
|
||||||
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
|
|
||||||
search_params.append(ivf_search_params)
|
|
||||||
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
|
|
||||||
for nprobe in [10]:
|
|
||||||
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
|
|
||||||
search_params.append(bin_search_params)
|
|
||||||
elif index_type in ["HNSW"]:
|
|
||||||
for ef in [64]:
|
|
||||||
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
|
|
||||||
search_params.append(hnsw_search_param)
|
|
||||||
elif index_type == "ANNOY":
|
|
||||||
for search_k in [1000]:
|
|
||||||
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
|
|
||||||
search_params.append(annoy_search_param)
|
|
||||||
else:
|
|
||||||
log.info("Invalid index_type.")
|
|
||||||
raise Exception("Invalid index_type.")
|
|
||||||
return search_params
|
|
||||||
|
|
||||||
|
|
||||||
def gen_search_params(index_type=None, anns_field=ct.default_float_vec_field_name,
|
|
||||||
topk=ct.default_top_k):
|
|
||||||
if index_type is None:
|
|
||||||
search_params = gen_search_param_by_index_type(ct.default_index_type)[0]
|
|
||||||
else:
|
|
||||||
search_params = gen_search_param_by_index_type(index_type)[0]
|
|
||||||
extra_params = []
|
|
||||||
for k, v in search_params.items():
|
|
||||||
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
|
|
||||||
extra_params.append(item)
|
|
||||||
extra_params.append({"key": "anns_field", "value": anns_field})
|
|
||||||
extra_params.append({"key": "topk", "value": str(topk)})
|
|
||||||
return extra_params
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gen_search_vectors(dim, nb, is_binary=False):
|
|
||||||
if is_binary:
|
|
||||||
return gen_binary_vectors(nb, dim)
|
|
||||||
return gen_float_vectors(nb, dim)
|
|
||||||
|
|
||||||
|
|
||||||
def modify_file(file_path_list, is_modify=False, input_content=""):
|
|
||||||
"""
|
|
||||||
file_path_list : file list -> list[<file_path>]
|
|
||||||
is_modify : does the file need to be reset
|
|
||||||
input_content :the content that need to insert to the file
|
|
||||||
"""
|
|
||||||
if not isinstance(file_path_list, list):
|
|
||||||
log.error("[modify_file] file is not a list.")
|
|
||||||
|
|
||||||
for file_path in file_path_list:
|
|
||||||
folder_path, file_name = os.path.split(file_path)
|
|
||||||
if not os.path.isdir(folder_path):
|
|
||||||
log.debug("[modify_file] folder(%s) is not exist." % folder_path)
|
|
||||||
os.makedirs(folder_path)
|
|
||||||
|
|
||||||
if not os.path.isfile(file_path):
|
|
||||||
log.error("[modify_file] file(%s) is not exist." % file_path)
|
|
||||||
else:
|
|
||||||
if is_modify is True:
|
|
||||||
log.debug("[modify_file] start modifying file(%s)..." % file_path)
|
|
||||||
with open(file_path, "r+") as f:
|
|
||||||
f.seek(0)
|
|
||||||
f.truncate()
|
|
||||||
f.write(input_content)
|
|
||||||
f.close()
|
|
||||||
log.info("[modify_file] file(%s) modification is complete." % file_path_list)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
a = gen_binary_vectors(10, 128)
|
|
||||||
print(a)
|
|
@ -1,80 +0,0 @@
|
|||||||
""" Initialized parameters """
|
|
||||||
port = 19530
|
|
||||||
epsilon = 0.000001
|
|
||||||
namespace = "milvus"
|
|
||||||
default_flush_interval = 1
|
|
||||||
big_flush_interval = 1000
|
|
||||||
default_drop_interval = 3
|
|
||||||
default_dim = 128
|
|
||||||
default_nb = 3000
|
|
||||||
default_nb_medium = 5000
|
|
||||||
default_top_k = 10
|
|
||||||
default_nq = 2
|
|
||||||
default_limit = 10
|
|
||||||
default_search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
|
||||||
default_search_ip_params = {"metric_type": "IP", "params": {"nprobe": 10}}
|
|
||||||
default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
|
|
||||||
default_index_type = "HNSW"
|
|
||||||
default_index_params = {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"}
|
|
||||||
default_varchar_index = {}
|
|
||||||
default_binary_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"}
|
|
||||||
default_diskann_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
|
|
||||||
default_diskann_search_params = {"metric_type": "L2", "params": {"search_list": 30}}
|
|
||||||
max_top_k = 16384
|
|
||||||
max_partition_num = 4096 # 256
|
|
||||||
default_segment_row_limit = 1000
|
|
||||||
default_server_segment_row_limit = 1024 * 512
|
|
||||||
default_alias = "default"
|
|
||||||
default_user = "root"
|
|
||||||
default_password = "Milvus"
|
|
||||||
default_bool_field_name = "Bool"
|
|
||||||
default_int8_field_name = "Int8"
|
|
||||||
default_int16_field_name = "Int16"
|
|
||||||
default_int32_field_name = "Int32"
|
|
||||||
default_int64_field_name = "Int64"
|
|
||||||
default_float_field_name = "Float"
|
|
||||||
default_double_field_name = "Double"
|
|
||||||
default_string_field_name = "Varchar"
|
|
||||||
default_float_vec_field_name = "FloatVector"
|
|
||||||
another_float_vec_field_name = "FloatVector1"
|
|
||||||
default_binary_vec_field_name = "BinaryVector"
|
|
||||||
default_partition_name = "_default"
|
|
||||||
default_tag = "1970_01_01"
|
|
||||||
row_count = "row_count"
|
|
||||||
default_length = 65535
|
|
||||||
default_desc = ""
|
|
||||||
default_collection_desc = "default collection"
|
|
||||||
default_index_name = "default_index_name"
|
|
||||||
default_binary_desc = "default binary collection"
|
|
||||||
collection_desc = "collection"
|
|
||||||
int_field_desc = "int64 type field"
|
|
||||||
float_field_desc = "float type field"
|
|
||||||
float_vec_field_desc = "float vector type field"
|
|
||||||
binary_vec_field_desc = "binary vector type field"
|
|
||||||
max_dim = 32768
|
|
||||||
min_dim = 1
|
|
||||||
gracefulTime = 1
|
|
||||||
default_nlist = 128
|
|
||||||
compact_segment_num_threshold = 4
|
|
||||||
compact_delta_ratio_reciprocal = 5 # compact_delta_binlog_ratio is 0.2
|
|
||||||
compact_retention_duration = 40 # compaction travel time retention range 20s
|
|
||||||
max_compaction_interval = 60 # the max time interval (s) from the last compaction
|
|
||||||
max_field_num = 256 # Maximum number of fields in a collection
|
|
||||||
|
|
||||||
default_dsl = f"{default_int64_field_name} in [2,4,6,8]"
|
|
||||||
default_expr = f"{default_int64_field_name} in [2,4,6,8]"
|
|
||||||
|
|
||||||
metric_types = []
|
|
||||||
all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT"]
|
|
||||||
all_index_params_map = {"FLAT": {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
|
|
||||||
"IVF_FLAT": {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
|
|
||||||
"IVF_SQ8": {"index_type": "IVF_SQ8", "params": {"nlist": 128}, "metric_type": "L2"},
|
|
||||||
"IVF_PQ": {"index_type": "IVF_PQ", "params": {"nlist": 128, "m": 16, "nbits": 8},
|
|
||||||
"metric_type": "L2"},
|
|
||||||
"HNSW": {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"},
|
|
||||||
"ANNOY": {"index_type": "ANNOY", "params": {"n_trees": 50}, "metric_type": "L2"},
|
|
||||||
"DISKANN": {"index_type": "DISKANN", "params": {}, "metric_type": "L2"},
|
|
||||||
"BIN_FLAT": {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"},
|
|
||||||
"BIN_IVF_FLAT": {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128},
|
|
||||||
"metric_type": "JACCARD"}
|
|
||||||
}
|
|
@ -1,15 +1,12 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import common.common_func as cf
|
import yaml
|
||||||
from check.param_check import ip_check, number_check
|
|
||||||
from config.log_config import log_config
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
from common.common_func import param_info
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption("--host", action="store", default="127.0.0.1", help="Milvus host")
|
parser.addoption("--host", action="store", default="127.0.0.1", help="host")
|
||||||
parser.addoption("--port", action="store", default="9091", help="Milvus http port")
|
parser.addoption("--port", action="store", default="19530", help="port")
|
||||||
parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing")
|
parser.addoption("--username", action="store", default="root", help="email")
|
||||||
|
parser.addoption("--password", action="store", default="Milvus", help="password")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -23,27 +20,11 @@ def port(request):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def clean_log(request):
|
def username(request):
|
||||||
return request.config.getoption("--clean_log")
|
return request.config.getoption("--username")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture
|
||||||
def initialize_env(request):
|
def password(request):
|
||||||
""" clean log before testing """
|
return request.config.getoption("--password")
|
||||||
host = request.config.getoption("--host")
|
|
||||||
port = request.config.getoption("--port")
|
|
||||||
clean_log = request.config.getoption("--clean_log")
|
|
||||||
|
|
||||||
|
|
||||||
""" params check """
|
|
||||||
assert ip_check(host) and number_check(port)
|
|
||||||
|
|
||||||
""" modify log files """
|
|
||||||
file_path_list = [log_config.log_debug, log_config.log_info, log_config.log_err]
|
|
||||||
if log_config.log_worker != "":
|
|
||||||
file_path_list.append(log_config.log_worker)
|
|
||||||
cf.modify_file(file_path_list=file_path_list, is_modify=clean_log)
|
|
||||||
|
|
||||||
log.info("#" * 80)
|
|
||||||
log.info("[initialize_milvus] Log cleaned up, start testing...")
|
|
||||||
param_info.prepare_param_info(host, port)
|
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: openapi.json
|
|
||||||
# timestamp: 2022-12-08T02:46:08+00:00
|
|
@ -1,28 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: openapi.json
|
|
||||||
# timestamp: 2022-12-08T02:46:08+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class KeyDataPair(BaseModel):
|
|
||||||
data: Optional[List[int]] = None
|
|
||||||
key: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class KeyValuePair(BaseModel):
|
|
||||||
key: Optional[str] = Field(None, example='dim')
|
|
||||||
value: Optional[str] = Field(None, example='128')
|
|
||||||
|
|
||||||
|
|
||||||
class MsgBase(BaseModel):
|
|
||||||
msg_type: Optional[int] = Field(None, description='Not useful for now')
|
|
||||||
|
|
||||||
|
|
||||||
class Status(BaseModel):
|
|
||||||
error_code: Optional[int] = None
|
|
||||||
reason: Optional[str] = None
|
|
@ -1,138 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: openapi.json
|
|
||||||
# timestamp: 2022-12-08T02:46:08+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from models import common, schema
|
|
||||||
|
|
||||||
|
|
||||||
class DescribeCollectionRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
collectionID: Optional[int] = Field(
|
|
||||||
None, description='The collection ID you want to describe'
|
|
||||||
)
|
|
||||||
time_stamp: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='If time_stamp is not zero, will describe collection success when time_stamp >= created collection timestamp, otherwise will throw error.',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DropCollectionRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The unique collection name in milvus.(Required)'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FieldData(BaseModel):
|
|
||||||
field: Optional[List] = None
|
|
||||||
field_id: Optional[int] = None
|
|
||||||
field_name: Optional[str] = None
|
|
||||||
type: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GetCollectionStatisticsRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name you want get statistics'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HasCollectionRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The unique collection name in milvus.(Required)'
|
|
||||||
)
|
|
||||||
time_stamp: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='If time_stamp is not zero, will return true when time_stamp >= created collection timestamp, otherwise will return false.',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InsertRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
fields_data: Optional[List[FieldData]] = None
|
|
||||||
hash_keys: Optional[List[int]] = None
|
|
||||||
num_rows: Optional[int] = None
|
|
||||||
partition_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class LoadCollectionRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name you want to load'
|
|
||||||
)
|
|
||||||
replica_number: Optional[int] = Field(
|
|
||||||
None, description='The replica number to load, default by 1'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReleaseCollectionRequest(BaseModel):
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name you want to release'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShowCollectionsRequest(BaseModel):
|
|
||||||
collection_names: Optional[List[str]] = Field(
|
|
||||||
None,
|
|
||||||
description="When type is InMemory, will return these collection's inMemory_percentages.(Optional)",
|
|
||||||
)
|
|
||||||
type: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='Decide return Loaded collections or All collections(Optional)',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorIDs(BaseModel):
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
field_name: Optional[str] = None
|
|
||||||
id_array: Optional[List[int]] = None
|
|
||||||
partition_names: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class VectorsArray(BaseModel):
|
|
||||||
binary_vectors: Optional[List[int]] = Field(
|
|
||||||
None,
|
|
||||||
description='Vectors is an array of binary vector divided by given dim. Disabled when IDs is set',
|
|
||||||
)
|
|
||||||
dim: Optional[int] = Field(
|
|
||||||
None, description='Dim of vectors or binary_vectors, not needed when use ids'
|
|
||||||
)
|
|
||||||
ids: Optional[VectorIDs] = None
|
|
||||||
vectors: Optional[List[float]] = Field(
|
|
||||||
None,
|
|
||||||
description='Vectors is an array of vector divided by given dim. Disabled when ids or binary_vectors is set',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CalcDistanceRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
op_left: Optional[VectorsArray] = None
|
|
||||||
op_right: Optional[VectorsArray] = None
|
|
||||||
params: Optional[List[common.KeyValuePair]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CreateCollectionRequest(BaseModel):
|
|
||||||
collection_name: str = Field(
|
|
||||||
...,
|
|
||||||
description='The unique collection name in milvus.(Required)',
|
|
||||||
example='book',
|
|
||||||
)
|
|
||||||
consistency_level: int = Field(
|
|
||||||
...,
|
|
||||||
description='The consistency level that the collection used, modification is not supported now.\n"Strong": 0,\n"Session": 1,\n"Bounded": 2,\n"Eventually": 3,\n"Customized": 4,',
|
|
||||||
example=1,
|
|
||||||
)
|
|
||||||
schema_: schema.CollectionSchema = Field(..., alias='schema')
|
|
||||||
shards_num: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='Once set, no modification is allowed (Optional)\nhttps://github.com/milvus-io/milvus/issues/6690',
|
|
||||||
example=1,
|
|
||||||
)
|
|
@ -1,72 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: openapi.json
|
|
||||||
# timestamp: 2022-12-08T02:46:08+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from models import common
|
|
||||||
|
|
||||||
|
|
||||||
class FieldData(BaseModel):
|
|
||||||
field: Optional[Any] = Field(
|
|
||||||
None,
|
|
||||||
description='Types that are assignable to Field:\n\t*FieldData_Scalars\n\t*FieldData_Vectors',
|
|
||||||
)
|
|
||||||
field_id: Optional[int] = None
|
|
||||||
field_name: Optional[str] = None
|
|
||||||
type: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FieldSchema(BaseModel):
|
|
||||||
autoID: Optional[bool] = None
|
|
||||||
data_type: int = Field(
|
|
||||||
...,
|
|
||||||
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
|
|
||||||
example=101,
|
|
||||||
)
|
|
||||||
description: Optional[str] = Field(
|
|
||||||
None, example='embedded vector of book introduction'
|
|
||||||
)
|
|
||||||
fieldID: Optional[int] = None
|
|
||||||
index_params: Optional[List[common.KeyValuePair]] = None
|
|
||||||
is_primary_key: Optional[bool] = Field(None, example=False)
|
|
||||||
name: str = Field(..., example='book_intro')
|
|
||||||
type_params: Optional[List[common.KeyValuePair]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class IDs(BaseModel):
|
|
||||||
idField: Optional[Any] = Field(
|
|
||||||
None,
|
|
||||||
description='Types that are assignable to IdField:\n\t*IDs_IntId\n\t*IDs_StrId',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LongArray(BaseModel):
|
|
||||||
data: Optional[List[int]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResultData(BaseModel):
|
|
||||||
fields_data: Optional[List[FieldData]] = None
|
|
||||||
ids: Optional[IDs] = None
|
|
||||||
num_queries: Optional[int] = None
|
|
||||||
scores: Optional[List[float]] = None
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
topks: Optional[List[int]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CollectionSchema(BaseModel):
|
|
||||||
autoID: Optional[bool] = Field(
|
|
||||||
None,
|
|
||||||
description='deprecated later, keep compatible with c++ part now',
|
|
||||||
example=False,
|
|
||||||
)
|
|
||||||
description: Optional[str] = Field(None, example='Test book search')
|
|
||||||
fields: Optional[List[FieldSchema]] = None
|
|
||||||
name: str = Field(..., example='book')
|
|
@ -1,587 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: openapi.json
|
|
||||||
# timestamp: 2022-12-08T02:46:08+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from models import common, schema
|
|
||||||
|
|
||||||
|
|
||||||
class AlterAliasRequest(BaseModel):
|
|
||||||
alias: Optional[str] = None
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class BoolResponse(BaseModel):
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
value: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CalcDistanceResults(BaseModel):
|
|
||||||
array: Optional[Any] = Field(
|
|
||||||
None,
|
|
||||||
description='num(op_left)*num(op_right) distance values, "HAMMIN" return integer distance\n\nTypes that are assignable to Array:\n\t*CalcDistanceResults_IntDist\n\t*CalcDistanceResults_FloatDist',
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompactionMergeInfo(BaseModel):
|
|
||||||
sources: Optional[List[int]] = None
|
|
||||||
target: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CreateAliasRequest(BaseModel):
|
|
||||||
alias: Optional[str] = None
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CreateCredentialRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
created_utc_timestamps: Optional[int] = Field(None, description='create time')
|
|
||||||
modified_utc_timestamps: Optional[int] = Field(None, description='modify time')
|
|
||||||
password: Optional[str] = Field(None, description='ciphertext password')
|
|
||||||
username: Optional[str] = Field(None, description='username')
|
|
||||||
|
|
||||||
|
|
||||||
class CreateIndexRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The particular collection name you want to create index.'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
extra_params: Optional[List[common.KeyValuePair]] = Field(
|
|
||||||
None,
|
|
||||||
description='Support keys: index_type,metric_type, params. Different index_type may has different params.',
|
|
||||||
)
|
|
||||||
field_name: Optional[str] = Field(
|
|
||||||
None, description='The vector field name in this particular collection'
|
|
||||||
)
|
|
||||||
index_name: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Version before 2.0.2 doesn't contain index_name, we use default index name.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CreatePartitionRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_name: Optional[str] = Field(
|
|
||||||
None, description='The partition name you want to create.'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteCredentialRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
username: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
expr: Optional[str] = None
|
|
||||||
hash_keys: Optional[List[int]] = None
|
|
||||||
partition_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class DescribeIndexRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The particular collection name in Milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
field_name: Optional[str] = Field(
|
|
||||||
None, description='The vector field name in this particular collection'
|
|
||||||
)
|
|
||||||
index_name: Optional[str] = Field(
|
|
||||||
None, description='No need to set up for now @2021.06.30'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DropAliasRequest(BaseModel):
|
|
||||||
alias: Optional[str] = None
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class DropIndexRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(None, description='must')
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
field_name: Optional[str] = None
|
|
||||||
index_name: Optional[str] = Field(
|
|
||||||
None, description='No need to set up for now @2021.06.30'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DropPartitionRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_name: Optional[str] = Field(
|
|
||||||
None, description='The partition name you want to drop'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlushRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_names: Optional[List[str]] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FlushResponse(BaseModel):
|
|
||||||
coll_segIDs: Optional[Dict[str, schema.LongArray]] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetCollectionStatisticsResponse(BaseModel):
|
|
||||||
stats: Optional[List[common.KeyValuePair]] = Field(
|
|
||||||
None, description='Collection statistics data'
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetCompactionPlansRequest(BaseModel):
|
|
||||||
compactionID: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetCompactionPlansResponse(BaseModel):
|
|
||||||
mergeInfos: Optional[List[CompactionMergeInfo]] = None
|
|
||||||
state: Optional[int] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetCompactionStateRequest(BaseModel):
|
|
||||||
compactionID: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetCompactionStateResponse(BaseModel):
|
|
||||||
completedPlanNo: Optional[int] = None
|
|
||||||
executingPlanNo: Optional[int] = None
|
|
||||||
state: Optional[int] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
timeoutPlanNo: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetFlushStateRequest(BaseModel):
|
|
||||||
segmentIDs: Optional[List[int]] = Field(None, alias='segment_ids')
|
|
||||||
|
|
||||||
|
|
||||||
class GetFlushStateResponse(BaseModel):
|
|
||||||
flushed: Optional[bool] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetImportStateRequest(BaseModel):
|
|
||||||
task: Optional[int] = Field(None, description='id of an import task')
|
|
||||||
|
|
||||||
|
|
||||||
class GetImportStateResponse(BaseModel):
|
|
||||||
id: Optional[int] = Field(None, description='id of an import task')
|
|
||||||
id_list: Optional[List[int]] = Field(
|
|
||||||
None, description='auto generated ids if the primary key is autoid'
|
|
||||||
)
|
|
||||||
infos: Optional[List[common.KeyValuePair]] = Field(
|
|
||||||
None,
|
|
||||||
description='more informations about the task, progress percent, file path, failed reason, etc.',
|
|
||||||
)
|
|
||||||
row_count: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='if the task is finished, this value is how many rows are imported. if the task is not finished, this value is how many rows are parsed. return 0 if failed.',
|
|
||||||
)
|
|
||||||
state: Optional[int] = Field(
|
|
||||||
None, description='is this import task finished or not'
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetIndexBuildProgressRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
field_name: Optional[str] = Field(
|
|
||||||
None, description='The vector field name in this collection'
|
|
||||||
)
|
|
||||||
index_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
|
|
||||||
|
|
||||||
class GetIndexBuildProgressResponse(BaseModel):
|
|
||||||
indexed_rows: Optional[int] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
total_rows: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetIndexStateRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(None, description='must')
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
field_name: Optional[str] = None
|
|
||||||
index_name: Optional[str] = Field(
|
|
||||||
None, description='No need to set up for now @2021.06.30'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GetIndexStateResponse(BaseModel):
|
|
||||||
fail_reason: Optional[str] = None
|
|
||||||
state: Optional[int] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetMetricsRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
request: Optional[str] = Field(None, description='request is of jsonic format')
|
|
||||||
|
|
||||||
|
|
||||||
class GetMetricsResponse(BaseModel):
|
|
||||||
component_name: Optional[str] = Field(
|
|
||||||
None, description='metrics from which component'
|
|
||||||
)
|
|
||||||
response: Optional[str] = Field(None, description='response is of jsonic format')
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetPartitionStatisticsRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_name: Optional[str] = Field(
|
|
||||||
None, description='The partition name you want to collect statistics'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GetPartitionStatisticsResponse(BaseModel):
|
|
||||||
stats: Optional[List[common.KeyValuePair]] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetPersistentSegmentInfoRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collectionName: Optional[str] = Field(None, alias="collection_name", description='must')
|
|
||||||
dbName: Optional[str] = Field(None, alias="db_name")
|
|
||||||
|
|
||||||
|
|
||||||
class GetQuerySegmentInfoRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collectionName: Optional[str] = Field(None, alias="collection_name", description='must')
|
|
||||||
dbName: Optional[str] = Field(None, alias="db_name")
|
|
||||||
|
|
||||||
|
|
||||||
class GetReplicasRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collectionID: Optional[int] = Field(None, alias="collection_id")
|
|
||||||
with_shard_nodes: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class HasPartitionRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_name: Optional[str] = Field(
|
|
||||||
None, description='The partition name you want to check'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImportRequest(BaseModel):
|
|
||||||
channel_names: Optional[List[str]] = Field(
|
|
||||||
None, description='channel names for the collection'
|
|
||||||
)
|
|
||||||
collection_name: Optional[str] = Field(None, description='target collection')
|
|
||||||
files: Optional[List[str]] = Field(None, description='file paths to be imported')
|
|
||||||
options: Optional[List[common.KeyValuePair]] = Field(
|
|
||||||
None, description='import options, bucket, etc.'
|
|
||||||
)
|
|
||||||
partition_name: Optional[str] = Field(None, description='target partition')
|
|
||||||
row_based: Optional[bool] = Field(
|
|
||||||
None, description='the file is row-based or column-based'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImportResponse(BaseModel):
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
tasks: Optional[List[int]] = Field(None, description='id array of import tasks')
|
|
||||||
|
|
||||||
|
|
||||||
class IndexDescription(BaseModel):
|
|
||||||
field_name: Optional[str] = Field(None, description='The vector field name')
|
|
||||||
index_name: Optional[str] = Field(None, description='Index name')
|
|
||||||
indexID: Optional[int] = Field(None, description='Index id')
|
|
||||||
params: Optional[List[common.KeyValuePair]] = Field(
|
|
||||||
None, description='Will return index_type, metric_type, params(like nlist).'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ListCredUsersRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ListCredUsersResponse(BaseModel):
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
usernames: Optional[List[str]] = Field(None, description='username array')
|
|
||||||
|
|
||||||
|
|
||||||
class ListImportTasksRequest(BaseModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ListImportTasksResponse(BaseModel):
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
tasks: Optional[List[GetImportStateResponse]] = Field(
|
|
||||||
None, description='list of all import tasks'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadBalanceRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collectionName: Optional[str] = None
|
|
||||||
dst_nodeIDs: Optional[List[int]] = None
|
|
||||||
sealed_segmentIDs: Optional[List[int]] = None
|
|
||||||
src_nodeID: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class LoadPartitionsRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_names: Optional[List[str]] = Field(
|
|
||||||
None, description='The partition names you want to load'
|
|
||||||
)
|
|
||||||
replica_number: Optional[int] = Field(
|
|
||||||
None, description='The replicas number you would load, 1 by default'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ManualCompactionRequest(BaseModel):
|
|
||||||
collectionID: Optional[int] = None
|
|
||||||
timetravel: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ManualCompactionResponse(BaseModel):
|
|
||||||
compactionID: Optional[int] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PersistentSegmentInfo(BaseModel):
|
|
||||||
collectionID: Optional[int] = None
|
|
||||||
num_rows: Optional[int] = None
|
|
||||||
partitionID: Optional[int] = None
|
|
||||||
segmentID: Optional[int] = None
|
|
||||||
state: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
expr: Optional[str] = None
|
|
||||||
guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp')
|
|
||||||
output_fields: Optional[List[str]] = None
|
|
||||||
partition_names: Optional[List[str]] = None
|
|
||||||
travel_timestamp: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueryResults(BaseModel):
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
fields_data: Optional[List[schema.FieldData]] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySegmentInfo(BaseModel):
|
|
||||||
collectionID: Optional[int] = None
|
|
||||||
index_name: Optional[str] = None
|
|
||||||
indexID: Optional[int] = None
|
|
||||||
mem_size: Optional[int] = None
|
|
||||||
nodeID: Optional[int] = None
|
|
||||||
num_rows: Optional[int] = None
|
|
||||||
partitionID: Optional[int] = None
|
|
||||||
segmentID: Optional[int] = None
|
|
||||||
state: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ReleasePartitionsRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None, description='The collection name in milvus'
|
|
||||||
)
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_names: Optional[List[str]] = Field(
|
|
||||||
None, description='The partition names you want to release'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SearchRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(None, description='must')
|
|
||||||
db_name: Optional[str] = None
|
|
||||||
dsl: Optional[str] = Field(None, description='must')
|
|
||||||
dsl_type: Optional[int] = Field(None, description='must')
|
|
||||||
guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp')
|
|
||||||
output_fields: Optional[List[str]] = None
|
|
||||||
partition_names: Optional[List[str]] = Field(None, description='must')
|
|
||||||
placeholder_group: Optional[List[int]] = Field(
|
|
||||||
None, description='serialized `PlaceholderGroup`'
|
|
||||||
)
|
|
||||||
search_params: Optional[List[common.KeyValuePair]] = Field(None, description='must')
|
|
||||||
travel_timestamp: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResults(BaseModel):
|
|
||||||
collection_name: Optional[str] = None
|
|
||||||
results: Optional[schema.SearchResultData] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ShardReplica(BaseModel):
|
|
||||||
dm_channel_name: Optional[str] = None
|
|
||||||
leader_addr: Optional[str] = Field(None, description='IP:port')
|
|
||||||
leaderID: Optional[int] = None
|
|
||||||
node_ids: Optional[List[int]] = Field(
|
|
||||||
None,
|
|
||||||
description='optional, DO NOT save it in meta, set it only for GetReplicas()\nif with_shard_nodes is true',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShowCollectionsResponse(BaseModel):
|
|
||||||
collection_ids: Optional[List[int]] = Field(None, description='Collection Id array')
|
|
||||||
collection_names: Optional[List[str]] = Field(
|
|
||||||
None, description='Collection name array'
|
|
||||||
)
|
|
||||||
created_timestamps: Optional[List[int]] = Field(
|
|
||||||
None, description='Hybrid timestamps in milvus'
|
|
||||||
)
|
|
||||||
created_utc_timestamps: Optional[List[int]] = Field(
|
|
||||||
None, description='The utc timestamp calculated by created_timestamp'
|
|
||||||
)
|
|
||||||
inMemory_percentages: Optional[List[int]] = Field(
|
|
||||||
None, description='Load percentage on querynode when type is InMemory'
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ShowPartitionsRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
collection_name: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='The collection name you want to describe, you can pass collection_name or collectionID',
|
|
||||||
)
|
|
||||||
collectionID: Optional[int] = Field(None, description='The collection id in milvus')
|
|
||||||
db_name: Optional[str] = Field(None, description='Not useful for now')
|
|
||||||
partition_names: Optional[List[str]] = Field(
|
|
||||||
None,
|
|
||||||
description="When type is InMemory, will return these patitions' inMemory_percentages.(Optional)",
|
|
||||||
)
|
|
||||||
type: Optional[int] = Field(
|
|
||||||
None, description='Decide return Loaded partitions or All partitions(Optional)'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShowPartitionsResponse(BaseModel):
|
|
||||||
created_timestamps: Optional[List[int]] = Field(
|
|
||||||
None, description='All hybrid timestamps'
|
|
||||||
)
|
|
||||||
created_utc_timestamps: Optional[List[int]] = Field(
|
|
||||||
None, description='All utc timestamps calculated by created_timestamps'
|
|
||||||
)
|
|
||||||
inMemory_percentages: Optional[List[int]] = Field(
|
|
||||||
None, description='Load percentage on querynode'
|
|
||||||
)
|
|
||||||
partition_names: Optional[List[str]] = Field(
|
|
||||||
None, description='All partition names for this collection'
|
|
||||||
)
|
|
||||||
partitionIDs: Optional[List[int]] = Field(
|
|
||||||
None, description='All partition ids for this collection'
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateCredentialRequest(BaseModel):
|
|
||||||
base: Optional[common.MsgBase] = None
|
|
||||||
created_utc_timestamps: Optional[int] = Field(None, description='create time')
|
|
||||||
modified_utc_timestamps: Optional[int] = Field(None, description='modify time')
|
|
||||||
newPassword: Optional[str] = Field(None, description='new password')
|
|
||||||
oldPassword: Optional[str] = Field(None, description='old password')
|
|
||||||
username: Optional[str] = Field(None, description='username')
|
|
||||||
|
|
||||||
|
|
||||||
class DescribeCollectionResponse(BaseModel):
|
|
||||||
aliases: Optional[List[str]] = Field(
|
|
||||||
None, description='The aliases of this collection'
|
|
||||||
)
|
|
||||||
collection_name: Optional[str] = Field(None, description='The collection name')
|
|
||||||
collectionID: Optional[int] = Field(None, description='The collection id')
|
|
||||||
consistency_level: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='The consistency level that the collection used, modification is not supported now.',
|
|
||||||
)
|
|
||||||
created_timestamp: Optional[int] = Field(
|
|
||||||
None, description='Hybrid timestamp in milvus'
|
|
||||||
)
|
|
||||||
created_utc_timestamp: Optional[int] = Field(
|
|
||||||
None, description='The utc timestamp calculated by created_timestamp'
|
|
||||||
)
|
|
||||||
physical_channel_names: Optional[List[str]] = Field(
|
|
||||||
None, description='System design related, users should not perceive'
|
|
||||||
)
|
|
||||||
schema_: Optional[schema.CollectionSchema] = Field(None, alias='schema')
|
|
||||||
shards_num: Optional[int] = Field(None, description='The shards number you set.')
|
|
||||||
start_positions: Optional[List[common.KeyDataPair]] = Field(
|
|
||||||
None, description='The message ID/posititon when collection is created'
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
virtual_channel_names: Optional[List[str]] = Field(
|
|
||||||
None, description='System design related, users should not perceive'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DescribeIndexResponse(BaseModel):
|
|
||||||
index_descriptions: Optional[List[IndexDescription]] = Field(
|
|
||||||
None,
|
|
||||||
description='All index informations, for now only return tha latest index you created for the collection.',
|
|
||||||
)
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetPersistentSegmentInfoResponse(BaseModel):
|
|
||||||
infos: Optional[List[PersistentSegmentInfo]] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetQuerySegmentInfoResponse(BaseModel):
|
|
||||||
infos: Optional[List[QuerySegmentInfo]] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ReplicaInfo(BaseModel):
|
|
||||||
collectionID: Optional[int] = None
|
|
||||||
node_ids: Optional[List[int]] = Field(None, description='include leaders')
|
|
||||||
partition_ids: Optional[List[int]] = Field(
|
|
||||||
None, description='empty indicates to load collection'
|
|
||||||
)
|
|
||||||
replicaID: Optional[int] = None
|
|
||||||
shard_replicas: Optional[List[ShardReplica]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class GetReplicasResponse(BaseModel):
|
|
||||||
replicas: Optional[List[ReplicaInfo]] = None
|
|
||||||
status: Optional[common.Status] = None
|
|
@ -1,12 +1,14 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
|
addopts = --strict --host 127.0.0.1 --port 19530 --username root --password Milvus --log-cli-level=INFO --capture=no
|
||||||
|
|
||||||
addopts = --host 10.101.178.131 --html=/tmp/ci_logs/report.html --self-contained-html -v
|
|
||||||
# python3 -W ignore -m pytest
|
|
||||||
|
|
||||||
log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
||||||
log_date_format = %Y-%m-%d %H:%M:%S
|
log_date_format = %Y-%m-%d %H:%M:%S
|
||||||
|
|
||||||
|
|
||||||
filterwarnings =
|
filterwarnings =
|
||||||
ignore::DeprecationWarning
|
ignore::DeprecationWarning
|
||||||
|
|
||||||
|
markers =
|
||||||
|
L0 : 'L0 case, high priority'
|
||||||
|
L1 : 'L1 case, second priority'
|
||||||
|
|
||||||
|
@ -1,2 +1,10 @@
|
|||||||
decorest~=0.1.0
|
requests~=2.26.0
|
||||||
pydantic~=1.10.2
|
urllib3==1.26.16
|
||||||
|
loguru~=0.5.3
|
||||||
|
pytest~=7.2.0
|
||||||
|
pyyaml~=6.0
|
||||||
|
numpy~=1.24.3
|
||||||
|
allure-pytest>=2.8.18
|
||||||
|
Faker==19.2.0
|
||||||
|
pymilvus~=2.2.9
|
||||||
|
scikit-learn~=1.1.3
|
395
tests/restful_client/testcases/test_collection_operations.py
Normal file
395
tests/restful_client/testcases/test_collection_operations.py
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from utils.util_log import test_log as logger
|
||||||
|
from utils.utils import gen_collection_name
|
||||||
|
import pytest
|
||||||
|
from api.milvus import CollectionClient
|
||||||
|
from base.testbase import TestBase
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
class TestCreateCollection(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("vector_field", [None, "vector", "emb"])
|
||||||
|
@pytest.mark.parametrize("primary_field", [None, "id", "doc_id"])
|
||||||
|
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||||
|
@pytest.mark.parametrize("dim", [32, 32768])
|
||||||
|
@pytest.mark.parametrize("db_name", ["prod", "default"])
|
||||||
|
def test_create_collections_default(self, dim, metric_type, primary_field, vector_field, db_name):
|
||||||
|
"""
|
||||||
|
target: test create collection
|
||||||
|
method: create a collection with a simple schema
|
||||||
|
expected: create collection success
|
||||||
|
"""
|
||||||
|
self.create_database(db_name)
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
client.db_name = db_name
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
"metricType": metric_type,
|
||||||
|
"primaryField": primary_field,
|
||||||
|
"vectorField": vector_field,
|
||||||
|
}
|
||||||
|
if primary_field is None:
|
||||||
|
del payload["primaryField"]
|
||||||
|
if vector_field is None:
|
||||||
|
del payload["vectorField"]
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
rsp = client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['collectionName'] == name
|
||||||
|
|
||||||
|
def test_create_collections_concurrent_with_same_param(self):
|
||||||
|
"""
|
||||||
|
target: test create collection with same param
|
||||||
|
method: concurrent create collections with same param with multi thread
|
||||||
|
expected: create collections all success
|
||||||
|
"""
|
||||||
|
concurrent_rsp = []
|
||||||
|
|
||||||
|
def create_collection(c_name, vector_dim, c_metric_type):
|
||||||
|
collection_payload = {
|
||||||
|
"collectionName": c_name,
|
||||||
|
"dimension": vector_dim,
|
||||||
|
"metricType": c_metric_type,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(collection_payload)
|
||||||
|
concurrent_rsp.append(rsp)
|
||||||
|
logger.info(rsp)
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
metric_type = "L2"
|
||||||
|
client = self.collection_client
|
||||||
|
threads = []
|
||||||
|
for i in range(10):
|
||||||
|
t = threading.Thread(target=create_collection, args=(name, dim, metric_type,))
|
||||||
|
threads.append(t)
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
time.sleep(10)
|
||||||
|
success_cnt = 0
|
||||||
|
for rsp in concurrent_rsp:
|
||||||
|
if rsp["code"] == 200:
|
||||||
|
success_cnt += 1
|
||||||
|
logger.info(concurrent_rsp)
|
||||||
|
assert success_cnt == 10
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
rsp = client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['collectionName'] == name
|
||||||
|
assert f"FloatVector({dim})" in str(rsp['data']['fields'])
|
||||||
|
|
||||||
|
def test_create_collections_concurrent_with_different_param(self):
|
||||||
|
"""
|
||||||
|
target: test create collection with different param
|
||||||
|
method: concurrent create collections with different param with multi thread
|
||||||
|
expected: only one collection can success
|
||||||
|
"""
|
||||||
|
concurrent_rsp = []
|
||||||
|
|
||||||
|
def create_collection(c_name, vector_dim, c_metric_type):
|
||||||
|
collection_payload = {
|
||||||
|
"collectionName": c_name,
|
||||||
|
"dimension": vector_dim,
|
||||||
|
"metricType": c_metric_type,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(collection_payload)
|
||||||
|
concurrent_rsp.append(rsp)
|
||||||
|
logger.info(rsp)
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
threads = []
|
||||||
|
for i in range(0, 5):
|
||||||
|
t = threading.Thread(target=create_collection, args=(name, dim + i, "L2",))
|
||||||
|
threads.append(t)
|
||||||
|
for i in range(5, 10):
|
||||||
|
t = threading.Thread(target=create_collection, args=(name, dim + i, "IP",))
|
||||||
|
threads.append(t)
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
time.sleep(10)
|
||||||
|
success_cnt = 0
|
||||||
|
for rsp in concurrent_rsp:
|
||||||
|
if rsp["code"] == 200:
|
||||||
|
success_cnt += 1
|
||||||
|
logger.info(concurrent_rsp)
|
||||||
|
assert success_cnt == 1
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
rsp = client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['collectionName'] == name
|
||||||
|
|
||||||
|
def test_create_collections_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
target: test create collection with invalid api key(wrong username and password)
|
||||||
|
method: create collections with invalid api key
|
||||||
|
expected: create collection failed
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
client.api_key = "illegal_api_key"
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", [" ", "test_collection_" * 100, "test collection", "test/collection", "test\collection"])
|
||||||
|
def test_create_collections_with_invalid_collection_name(self, name):
|
||||||
|
"""
|
||||||
|
target: test create collection with invalid collection name
|
||||||
|
method: create collections with invalid collection name
|
||||||
|
expected: create collection failed with right error message
|
||||||
|
"""
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
class TestListCollections(TestBase):
|
||||||
|
|
||||||
|
def test_list_collections_default(self):
|
||||||
|
"""
|
||||||
|
target: test list collection with a simple schema
|
||||||
|
method: create collections and list them
|
||||||
|
expected: created collections are in list
|
||||||
|
"""
|
||||||
|
client = self.collection_client
|
||||||
|
name_list = []
|
||||||
|
for i in range(2):
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
time.sleep(1)
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
name_list.append(name)
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
for name in name_list:
|
||||||
|
assert name in all_collections
|
||||||
|
|
||||||
|
def test_list_collections_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
target: test list collection with an invalid api key
|
||||||
|
method: list collection with invalid api key
|
||||||
|
expected: raise error with right error code and message
|
||||||
|
"""
|
||||||
|
client = self.collection_client
|
||||||
|
name_list = []
|
||||||
|
for i in range(2):
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
time.sleep(1)
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
name_list.append(name)
|
||||||
|
client = self.collection_client
|
||||||
|
client.api_key = "illegal_api_key"
|
||||||
|
rsp = client.collection_list()
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
class TestDescribeCollection(TestBase):
|
||||||
|
|
||||||
|
|
||||||
|
def test_describe_collections_default(self):
|
||||||
|
"""
|
||||||
|
target: test describe collection with a simple schema
|
||||||
|
method: describe collection
|
||||||
|
expected: info of description is same with param passed to create collection
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
rsp = client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['collectionName'] == name
|
||||||
|
assert f"FloatVector({dim})" in str(rsp['data']['fields'])
|
||||||
|
|
||||||
|
def test_describe_collections_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
target: test describe collection with invalid api key
|
||||||
|
method: describe collection with invalid api key
|
||||||
|
expected: raise error with right error code and message
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
illegal_client = CollectionClient(self.url, "illegal_api_key")
|
||||||
|
rsp = illegal_client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
|
||||||
|
def test_describe_collections_with_invalid_collection_name(self):
|
||||||
|
"""
|
||||||
|
target: test describe collection with invalid collection name
|
||||||
|
method: describe collection with invalid collection name
|
||||||
|
expected: raise error with right error code and message
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# describe collection
|
||||||
|
invalid_name = "invalid_name"
|
||||||
|
rsp = client.collection_describe(invalid_name)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
class TestDropCollection(TestBase):
|
||||||
|
|
||||||
|
def test_drop_collections_default(self):
|
||||||
|
"""
|
||||||
|
Drop a collection with a simple schema
|
||||||
|
target: test drop collection with a simple schema
|
||||||
|
method: drop collection
|
||||||
|
expected: dropped collection was not in collection list
|
||||||
|
"""
|
||||||
|
clo_list = []
|
||||||
|
for i in range(5):
|
||||||
|
time.sleep(1)
|
||||||
|
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f_%f")
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": 128,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
clo_list.append(name)
|
||||||
|
rsp = self.collection_client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
for name in clo_list:
|
||||||
|
assert name in all_collections
|
||||||
|
for name in clo_list:
|
||||||
|
time.sleep(0.2)
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_drop(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
for name in clo_list:
|
||||||
|
assert name not in all_collections
|
||||||
|
|
||||||
|
def test_drop_collections_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
target: test drop collection with invalid api key
|
||||||
|
method: drop collection with invalid api key
|
||||||
|
expected: raise error with right error code and message; collection still in collection list
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# drop collection
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
}
|
||||||
|
illegal_client = CollectionClient(self.url, "invalid_api_key")
|
||||||
|
rsp = illegal_client.collection_drop(payload)
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
|
||||||
|
def test_drop_collections_with_invalid_collection_name(self):
|
||||||
|
"""
|
||||||
|
target: test drop collection with invalid collection name
|
||||||
|
method: drop collection with invalid collection name
|
||||||
|
expected: raise error with right error code and message
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
client = self.collection_client
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = client.collection_list()
|
||||||
|
all_collections = rsp['data']
|
||||||
|
assert name in all_collections
|
||||||
|
# drop collection
|
||||||
|
invalid_name = "invalid_name"
|
||||||
|
payload = {
|
||||||
|
"collectionName": invalid_name,
|
||||||
|
}
|
||||||
|
rsp = client.collection_drop(payload)
|
||||||
|
assert rsp['code'] == 100
|
@ -1,49 +0,0 @@
|
|||||||
from time import sleep
|
|
||||||
from common import common_type as ct
|
|
||||||
from common import common_func as cf
|
|
||||||
from base.client_base import TestBase
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
|
|
||||||
|
|
||||||
class TestDefault(TestBase):
|
|
||||||
|
|
||||||
def test_e2e(self):
|
|
||||||
collection_name, schema = self.init_collection()
|
|
||||||
nb = ct.default_nb
|
|
||||||
# insert
|
|
||||||
res = self.entity_service.insert(collection_name=collection_name, fields_data=cf.gen_fields_data(schema, nb=nb),
|
|
||||||
num_rows=nb)
|
|
||||||
log.info(f"insert {nb} rows into collection {collection_name}, response: {res}")
|
|
||||||
# flush
|
|
||||||
res = self.entity_service.flush(collection_names=[collection_name])
|
|
||||||
log.info(f"flush collection {collection_name}, response: {res}")
|
|
||||||
# create index for vector field
|
|
||||||
vector_field_name = cf.get_vector_field(schema)
|
|
||||||
vector_index_params = cf.gen_index_params(index_type="HNSW")
|
|
||||||
res = self.index_service.create_index(collection_name=collection_name, field_name=vector_field_name,
|
|
||||||
extra_params=vector_index_params)
|
|
||||||
log.info(f"create index for vector field {vector_field_name}, response: {res}")
|
|
||||||
# load
|
|
||||||
res = self.collection_service.load_collection(collection_name=collection_name)
|
|
||||||
log.info(f"load collection {collection_name}, response: {res}")
|
|
||||||
|
|
||||||
sleep(5)
|
|
||||||
# search
|
|
||||||
vectors = cf.gen_vectors(nq=ct.default_nq, schema=schema)
|
|
||||||
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
|
|
||||||
output_fields=[ct.default_int64_field_name],
|
|
||||||
search_params=cf.gen_search_params())
|
|
||||||
log.info(f"search collection {collection_name}, response: {res}")
|
|
||||||
|
|
||||||
# hybrid search
|
|
||||||
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
|
|
||||||
output_fields=[ct.default_int64_field_name],
|
|
||||||
search_params=cf.gen_search_params(),
|
|
||||||
dsl=ct.default_dsl)
|
|
||||||
log.info(f"hybrid search collection {collection_name}, response: {res}")
|
|
||||||
# query
|
|
||||||
res = self.entity_service.query(collection_name=collection_name, expr=ct.default_expr)
|
|
||||||
|
|
||||||
log.info(f"query collection {collection_name}, response: {res}")
|
|
||||||
|
|
||||||
|
|
984
tests/restful_client/testcases/test_vector_operations.py
Normal file
984
tests/restful_client/testcases/test_vector_operations.py
Normal file
@ -0,0 +1,984 @@
|
|||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from sklearn import preprocessing
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from utils import constant
|
||||||
|
from utils.utils import gen_collection_name
|
||||||
|
from utils.util_log import test_log as logger
|
||||||
|
import pytest
|
||||||
|
from api.milvus import VectorClient
|
||||||
|
from base.testbase import TestBase
|
||||||
|
from utils.utils import (get_data_by_fields, get_data_by_payload, get_common_fields_by_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsertVector(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("insert_round", [2, 1])
|
||||||
|
@pytest.mark.parametrize("nb", [100, 10, 1])
|
||||||
|
@pytest.mark.parametrize("dim", [32, 128])
|
||||||
|
@pytest.mark.parametrize("primary_field", ["id", "url"])
|
||||||
|
@pytest.mark.parametrize("vector_field", ["vector", "embedding"])
|
||||||
|
@pytest.mark.parametrize("db_name", ["prod", "default"])
|
||||||
|
def test_insert_vector_with_simple_payload(self, db_name, vector_field, primary_field, nb, dim, insert_round):
|
||||||
|
"""
|
||||||
|
Insert a vector with a simple payload
|
||||||
|
"""
|
||||||
|
self.update_database(db_name=db_name)
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
collection_payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
"primaryField": primary_field,
|
||||||
|
"vectorField": vector_field,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(collection_payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
logger.info(f"rsp: {rsp}")
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
for i in range(insert_round):
|
||||||
|
data = get_data_by_payload(collection_payload, nb)
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
rsp = self.vector_client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['insertCount'] == nb
|
||||||
|
logger.info("finished")
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("insert_round", [10])
|
||||||
|
def test_insert_vector_with_multi_round(self, insert_round):
|
||||||
|
"""
|
||||||
|
Insert a vector with a simple payload
|
||||||
|
"""
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
collection_payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": 768,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(collection_payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
logger.info(f"rsp: {rsp}")
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
nb = 300
|
||||||
|
for i in range(insert_round):
|
||||||
|
data = get_data_by_payload(collection_payload, nb)
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
rsp = self.vector_client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert rsp['data']['insertCount'] == nb
|
||||||
|
logger.info("finished")
|
||||||
|
|
||||||
|
def test_insert_vector_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
Insert a vector with invalid api key
|
||||||
|
"""
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
nb = 10
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"vector": [np.float64(random.random()) for _ in range(dim)],
|
||||||
|
} for _ in range(nb)
|
||||||
|
]
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
client = self.vector_client
|
||||||
|
client.api_key = "invalid_api_key"
|
||||||
|
rsp = client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
|
||||||
|
def test_insert_vector_with_invalid_collection_name(self):
|
||||||
|
"""
|
||||||
|
Insert a vector with an invalid collection name
|
||||||
|
"""
|
||||||
|
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
nb = 100
|
||||||
|
data = get_data_by_payload(payload, nb)
|
||||||
|
payload = {
|
||||||
|
"collectionName": "invalid_collection_name",
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
rsp = self.vector_client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
def test_insert_vector_with_invalid_database_name(self):
|
||||||
|
"""
|
||||||
|
Insert a vector with an invalid database name
|
||||||
|
"""
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 128
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
nb = 10
|
||||||
|
data = get_data_by_payload(payload, nb)
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
success = False
|
||||||
|
rsp = self.vector_client.vector_insert(payload, db_name="invalid_database")
|
||||||
|
assert rsp['code'] == 800
|
||||||
|
|
||||||
|
def test_insert_vector_with_mismatch_dim(self):
|
||||||
|
"""
|
||||||
|
Insert a vector with mismatch dim
|
||||||
|
"""
|
||||||
|
# create a collection
|
||||||
|
name = gen_collection_name()
|
||||||
|
dim = 32
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"dimension": dim,
|
||||||
|
}
|
||||||
|
rsp = self.collection_client.collection_create(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
rsp = self.collection_client.collection_describe(name)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
# insert data
|
||||||
|
nb = 1
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"vector": [np.float64(random.random()) for _ in range(dim + 1)],
|
||||||
|
} for i in range(nb)
|
||||||
|
]
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
body_size = sys.getsizeof(json.dumps(payload))
|
||||||
|
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||||
|
rsp = self.vector_client.vector_insert(payload)
|
||||||
|
assert rsp['code'] == 1804
|
||||||
|
assert rsp['message'] == "fail to deal the insert data"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchVector(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("metric_type", ["IP", "L2"])
|
||||||
|
def test_search_vector_with_simple_payload(self, metric_type):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
self.init_collection(name, metric_type=metric_type)
|
||||||
|
|
||||||
|
# search data
|
||||||
|
dim = 128
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
limit = int(payload.get("limit", 100))
|
||||||
|
assert len(res) == limit
|
||||||
|
ids = [item['id'] for item in res]
|
||||||
|
assert len(ids) == len(set(ids))
|
||||||
|
distance = [item['distance'] for item in res]
|
||||||
|
if metric_type == "L2":
|
||||||
|
assert distance == sorted(distance)
|
||||||
|
if metric_type == "IP":
|
||||||
|
assert distance == sorted(distance, reverse=True)
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("sum_limit_offset", [16384, 16385])
|
||||||
|
def test_search_vector_with_exceed_sum_limit_offset(self, sum_limit_offset):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
max_search_sum_limit_offset = constant.MAX_SUM_OFFSET_AND_LIMIT
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = sum_limit_offset + 2000
|
||||||
|
metric_type = "IP"
|
||||||
|
limit = 100
|
||||||
|
self.init_collection(name, metric_type=metric_type, nb=nb, batch_size=2000)
|
||||||
|
|
||||||
|
# search data
|
||||||
|
dim = 128
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": sum_limit_offset-limit,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
if sum_limit_offset > max_search_sum_limit_offset:
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
return
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
limit = int(payload.get("limit", 100))
|
||||||
|
assert len(res) == limit
|
||||||
|
ids = [item['id'] for item in res]
|
||||||
|
assert len(ids) == len(set(ids))
|
||||||
|
distance = [item['distance'] for item in res]
|
||||||
|
if metric_type == "L2":
|
||||||
|
assert distance == sorted(distance)
|
||||||
|
if metric_type == "IP":
|
||||||
|
assert distance == sorted(distance, reverse=True)
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("level", [0, 1, 2])
|
||||||
|
@pytest.mark.parametrize("offset", [0, 10, 100])
|
||||||
|
@pytest.mark.parametrize("limit", [1, 100])
|
||||||
|
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||||
|
def test_search_vector_with_complex_payload(self, limit, offset, level, metric_type):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = limit + offset + 100
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": "uid >= 0",
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
if offset + limit > constant.MAX_SUM_OFFSET_AND_LIMIT:
|
||||||
|
assert rsp['code'] == 90126
|
||||||
|
return
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) == limit
|
||||||
|
for item in res:
|
||||||
|
assert item.get("uid") >= 0
|
||||||
|
for field in output_fields:
|
||||||
|
assert field in item
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("filter_expr", ["uid >= 0", "uid >= 0 and uid < 100", "uid in [1,2,3]"])
|
||||||
|
def test_search_vector_with_complex_int_filter(self, filter_expr):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
limit = 100
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": filter_expr,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": 0,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) <= limit
|
||||||
|
for item in res:
|
||||||
|
uid = item.get("uid")
|
||||||
|
eval(filter_expr)
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""])
|
||||||
|
def test_search_vector_with_complex_varchar_filter(self, filter_expr):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
limit = 100
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
names = []
|
||||||
|
for item in data:
|
||||||
|
names.append(item.get("name"))
|
||||||
|
names.sort()
|
||||||
|
logger.info(f"names: {names}")
|
||||||
|
mid = len(names) // 2
|
||||||
|
prefix = names[mid][0:2]
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
filter_expr = filter_expr.replace("placeholder", prefix)
|
||||||
|
logger.info(f"filter_expr: {filter_expr}")
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": filter_expr,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": 0,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) <= limit
|
||||||
|
for item in res:
|
||||||
|
name = item.get("name")
|
||||||
|
logger.info(f"name: {name}")
|
||||||
|
if ">" in filter_expr:
|
||||||
|
assert name > prefix
|
||||||
|
if "like" in filter_expr:
|
||||||
|
assert name.startswith(prefix)
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("filter_expr", ["uid < 100 and name > \"placeholder\"",
|
||||||
|
"uid < 100 and name like \"placeholder%\""
|
||||||
|
])
|
||||||
|
def test_search_vector_with_complex_int64_varchar_and_filter(self, filter_expr):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
limit = 100
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
names = []
|
||||||
|
for item in data:
|
||||||
|
names.append(item.get("name"))
|
||||||
|
names.sort()
|
||||||
|
logger.info(f"names: {names}")
|
||||||
|
mid = len(names) // 2
|
||||||
|
prefix = names[mid][0:2]
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
filter_expr = filter_expr.replace("placeholder", prefix)
|
||||||
|
logger.info(f"filter_expr: {filter_expr}")
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": filter_expr,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": 0,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) <= limit
|
||||||
|
for item in res:
|
||||||
|
uid = item.get("uid")
|
||||||
|
name = item.get("name")
|
||||||
|
logger.info(f"name: {name}")
|
||||||
|
uid_expr = filter_expr.split("and")[0]
|
||||||
|
assert eval(uid_expr) is True
|
||||||
|
varchar_expr = filter_expr.split("and")[1]
|
||||||
|
if ">" in varchar_expr:
|
||||||
|
assert name > prefix
|
||||||
|
if "like" in varchar_expr:
|
||||||
|
assert name.startswith(prefix)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("limit", [0, 16385])
|
||||||
|
def test_search_vector_with_invalid_limit(self, limit):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim)
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": "uid >= 0",
|
||||||
|
"limit": limit,
|
||||||
|
"offset": 0,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("offset", [-1, 100_001])
|
||||||
|
def test_search_vector_with_invalid_offset(self, offset):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim)
|
||||||
|
vector_field = schema_payload.get("vectorField")
|
||||||
|
# search data
|
||||||
|
dim = 128
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": "uid >= 0",
|
||||||
|
"limit": 100,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
def test_search_vector_with_illegal_api_key(self):
|
||||||
|
"""
|
||||||
|
Search a vector with an illegal api key
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_search_vector_with_invalid_collection_name(self):
|
||||||
|
"""
|
||||||
|
Search a vector with an invalid collection name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_search_vector_with_invalid_output_field(self):
|
||||||
|
"""
|
||||||
|
Search a vector with an invalid output field
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid_expr", ["invalid_field > 0", "12-s", "中文", "a", " "])
|
||||||
|
def test_search_vector_with_invalid_expression(self, invalid_expr):
|
||||||
|
"""
|
||||||
|
Search a vector with an invalid expression
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_search_vector_with_invalid_vector_field(self):
|
||||||
|
"""
|
||||||
|
Search a vector with an invalid vector field for ann search
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dim_offset", [1, -1])
|
||||||
|
def test_search_vector_with_mismatch_vector_dim(self, dim_offset):
|
||||||
|
"""
|
||||||
|
Search a vector with a mismatch vector dim
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryVector(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("expr", ["10+20 <= uid < 20+30", "uid in [1,2,3,4]",
|
||||||
|
"uid > 0", "uid >= 0", "uid > 0",
|
||||||
|
"uid > -100 and uid < 100"])
|
||||||
|
@pytest.mark.parametrize("include_output_fields", [True, False])
|
||||||
|
@pytest.mark.parametrize("partial_fields", [True, False])
|
||||||
|
def test_query_vector_with_int64_filter(self, expr, include_output_fields, partial_fields):
|
||||||
|
"""
|
||||||
|
Query a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
schema_payload, data = self.init_collection(name)
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
if partial_fields:
|
||||||
|
output_fields = output_fields[:len(output_fields) // 2]
|
||||||
|
if "uid" not in output_fields:
|
||||||
|
output_fields.append("uid")
|
||||||
|
else:
|
||||||
|
output_fields = output_fields
|
||||||
|
|
||||||
|
# query data
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": expr,
|
||||||
|
"limit": 100,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": output_fields
|
||||||
|
}
|
||||||
|
if not include_output_fields:
|
||||||
|
payload.pop("outputFields")
|
||||||
|
if 'vector' in output_fields:
|
||||||
|
output_fields.remove("vector")
|
||||||
|
time.sleep(5)
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
for r in res:
|
||||||
|
uid = r['uid']
|
||||||
|
assert eval(expr) is True
|
||||||
|
for field in output_fields:
|
||||||
|
assert field in r
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""])
|
||||||
|
@pytest.mark.parametrize("include_output_fields", [True, False])
|
||||||
|
def test_query_vector_with_varchar_filter(self, filter_expr, include_output_fields):
|
||||||
|
"""
|
||||||
|
Query a vector with a complex payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
limit = 100
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
names = []
|
||||||
|
for item in data:
|
||||||
|
names.append(item.get("name"))
|
||||||
|
names.sort()
|
||||||
|
logger.info(f"names: {names}")
|
||||||
|
mid = len(names) // 2
|
||||||
|
prefix = names[mid][0:2]
|
||||||
|
# search data
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
filter_expr = filter_expr.replace("placeholder", prefix)
|
||||||
|
logger.info(f"filter_expr: {filter_expr}")
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": filter_expr,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": 0,
|
||||||
|
}
|
||||||
|
if not include_output_fields:
|
||||||
|
payload.pop("outputFields")
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) <= limit
|
||||||
|
for item in res:
|
||||||
|
name = item.get("name")
|
||||||
|
logger.info(f"name: {name}")
|
||||||
|
if ">" in filter_expr:
|
||||||
|
assert name > prefix
|
||||||
|
if "like" in filter_expr:
|
||||||
|
assert name.startswith(prefix)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sum_of_limit_offset", [16384, 16385])
|
||||||
|
def test_query_vector_with_large_sum_of_limit_offset(self, sum_of_limit_offset):
|
||||||
|
"""
|
||||||
|
Query a vector with sum of limit and offset larger than max value
|
||||||
|
"""
|
||||||
|
max_sum_of_limit_offset = 16384
|
||||||
|
name = gen_collection_name()
|
||||||
|
filter_expr = "name > \"placeholder\""
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
limit = 100
|
||||||
|
offset = sum_of_limit_offset - limit
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
names = []
|
||||||
|
for item in data:
|
||||||
|
names.append(item.get("name"))
|
||||||
|
names.sort()
|
||||||
|
logger.info(f"names: {names}")
|
||||||
|
mid = len(names) // 2
|
||||||
|
prefix = names[mid][0:2]
|
||||||
|
# search data
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
filter_expr = filter_expr.replace("placeholder", prefix)
|
||||||
|
logger.info(f"filter_expr: {filter_expr}")
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": filter_expr,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
if sum_of_limit_offset > max_sum_of_limit_offset:
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
return
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
assert len(res) <= limit
|
||||||
|
for item in res:
|
||||||
|
name = item.get("name")
|
||||||
|
logger.info(f"name: {name}")
|
||||||
|
if ">" in filter_expr:
|
||||||
|
assert name > prefix
|
||||||
|
if "like" in filter_expr:
|
||||||
|
assert name.startswith(prefix)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetVector(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
def test_get_vector_with_simple_payload(self):
|
||||||
|
"""
|
||||||
|
Search a vector with a simple payload
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
self.init_collection(name)
|
||||||
|
|
||||||
|
# search data
|
||||||
|
dim = 128
|
||||||
|
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"vector": vector_to_search,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_search(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
limit = int(payload.get("limit", 100))
|
||||||
|
assert len(res) == limit
|
||||||
|
ids = [item['id'] for item in res]
|
||||||
|
assert len(ids) == len(set(ids))
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": ["*"],
|
||||||
|
"id": ids[0],
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_get(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {res}")
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
for item in res:
|
||||||
|
assert item['id'] == ids[0]
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("id_field_type", ["list", "one"])
|
||||||
|
@pytest.mark.parametrize("include_invalid_id", [True, False])
|
||||||
|
@pytest.mark.parametrize("include_output_fields", [True, False])
|
||||||
|
def test_get_vector_complex(self, id_field_type, include_output_fields, include_invalid_id):
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
uids = []
|
||||||
|
for item in data:
|
||||||
|
uids.append(item.get("uid"))
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": f"uid in {uids}",
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
ids = []
|
||||||
|
for r in res:
|
||||||
|
ids.append(r['id'])
|
||||||
|
logger.info(f"ids: {len(ids)}")
|
||||||
|
id_to_get = None
|
||||||
|
if id_field_type == "list":
|
||||||
|
id_to_get = ids
|
||||||
|
if id_field_type == "one":
|
||||||
|
id_to_get = ids[0]
|
||||||
|
if include_invalid_id:
|
||||||
|
if isinstance(id_to_get, list):
|
||||||
|
id_to_get[-1] = 0
|
||||||
|
else:
|
||||||
|
id_to_get = 0
|
||||||
|
# get by id list
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"id": id_to_get
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_get(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
if isinstance(id_to_get, list):
|
||||||
|
if include_invalid_id:
|
||||||
|
assert len(res) == len(id_to_get) - 1
|
||||||
|
else:
|
||||||
|
assert len(res) == len(id_to_get)
|
||||||
|
else:
|
||||||
|
if include_invalid_id:
|
||||||
|
assert len(res) == 0
|
||||||
|
else:
|
||||||
|
assert len(res) == 1
|
||||||
|
for r in rsp['data']:
|
||||||
|
if isinstance(id_to_get, list):
|
||||||
|
assert r['id'] in id_to_get
|
||||||
|
else:
|
||||||
|
assert r['id'] == id_to_get
|
||||||
|
if include_output_fields:
|
||||||
|
for field in output_fields:
|
||||||
|
assert field in r
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteVector(TestBase):
|
||||||
|
|
||||||
|
@pytest.mark.L0
|
||||||
|
@pytest.mark.parametrize("include_invalid_id", [True, False])
|
||||||
|
@pytest.mark.parametrize("id_field_type", ["list", "one"])
|
||||||
|
def test_delete_vector_default(self, id_field_type, include_invalid_id):
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
time.sleep(1)
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
uids = []
|
||||||
|
for item in data:
|
||||||
|
uids.append(item.get("uid"))
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": f"uid in {uids}",
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
ids = []
|
||||||
|
for r in res:
|
||||||
|
ids.append(r['id'])
|
||||||
|
logger.info(f"ids: {len(ids)}")
|
||||||
|
id_to_get = None
|
||||||
|
if id_field_type == "list":
|
||||||
|
id_to_get = ids
|
||||||
|
if id_field_type == "one":
|
||||||
|
id_to_get = ids[0]
|
||||||
|
if include_invalid_id:
|
||||||
|
if isinstance(id_to_get, list):
|
||||||
|
id_to_get.append(0)
|
||||||
|
else:
|
||||||
|
id_to_get = 0
|
||||||
|
if isinstance(id_to_get, list):
|
||||||
|
if len(id_to_get) >= 100:
|
||||||
|
id_to_get = id_to_get[-100:]
|
||||||
|
# delete by id list
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"id": id_to_get
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_delete(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
logger.info(f"delete res: {rsp}")
|
||||||
|
|
||||||
|
# verify data deleted
|
||||||
|
if not isinstance(id_to_get, list):
|
||||||
|
id_to_get = [id_to_get]
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": f"id in {id_to_get}",
|
||||||
|
}
|
||||||
|
time.sleep(5)
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
assert len(rsp['data']) == 0
|
||||||
|
|
||||||
|
def test_delete_vector_with_invalid_api_key(self):
|
||||||
|
"""
|
||||||
|
Delete a vector with an invalid api key
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
nb = 200
|
||||||
|
dim = 128
|
||||||
|
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||||
|
output_fields = get_common_fields_by_data(data)
|
||||||
|
uids = []
|
||||||
|
for item in data:
|
||||||
|
uids.append(item.get("uid"))
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"outputFields": output_fields,
|
||||||
|
"filter": f"uid in {uids}",
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
ids = []
|
||||||
|
for r in res:
|
||||||
|
ids.append(r['id'])
|
||||||
|
logger.info(f"ids: {len(ids)}")
|
||||||
|
id_to_get = ids
|
||||||
|
# delete by id list
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"id": id_to_get
|
||||||
|
}
|
||||||
|
client = self.vector_client
|
||||||
|
client.api_key = "invalid_api_key"
|
||||||
|
rsp = client.vector_delete(payload)
|
||||||
|
assert rsp['code'] == 1800
|
||||||
|
|
||||||
|
def test_delete_vector_with_invalid_collection_name(self):
|
||||||
|
"""
|
||||||
|
Delete a vector with an invalid collection name
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
self.init_collection(name, dim=128, nb=3000)
|
||||||
|
|
||||||
|
# query data
|
||||||
|
# expr = f"id in {[i for i in range(10)]}".replace("[", "(").replace("]", ")")
|
||||||
|
expr = "id > 0"
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": expr,
|
||||||
|
"limit": 3000,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": ["id", "uid"]
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
id_list = [r['id'] for r in res]
|
||||||
|
delete_expr = f"id in {[i for i in id_list[:10]]}"
|
||||||
|
# query data before delete
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": delete_expr,
|
||||||
|
"limit": 3000,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": ["id", "uid"]
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
# delete data
|
||||||
|
payload = {
|
||||||
|
"collectionName": name + "_invalid",
|
||||||
|
"filter": delete_expr,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_delete(payload)
|
||||||
|
assert rsp['code'] == 1
|
||||||
|
|
||||||
|
def test_delete_vector_with_non_primary_key(self):
|
||||||
|
"""
|
||||||
|
Delete a vector with a non-primary key, expect no data were deleted
|
||||||
|
"""
|
||||||
|
name = gen_collection_name()
|
||||||
|
self.name = name
|
||||||
|
self.init_collection(name, dim=128, nb=300)
|
||||||
|
expr = "uid > 0"
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": expr,
|
||||||
|
"limit": 3000,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": ["id", "uid"]
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
id_list = [r['uid'] for r in res]
|
||||||
|
delete_expr = f"uid in {[i for i in id_list[:10]]}"
|
||||||
|
# query data before delete
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": delete_expr,
|
||||||
|
"limit": 3000,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": ["id", "uid"]
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert rsp['code'] == 200
|
||||||
|
res = rsp['data']
|
||||||
|
num_before_delete = len(res)
|
||||||
|
logger.info(f"res: {len(res)}")
|
||||||
|
# delete data
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": delete_expr,
|
||||||
|
}
|
||||||
|
rsp = self.vector_client.vector_delete(payload)
|
||||||
|
# query data after delete
|
||||||
|
payload = {
|
||||||
|
"collectionName": name,
|
||||||
|
"filter": delete_expr,
|
||||||
|
"limit": 3000,
|
||||||
|
"offset": 0,
|
||||||
|
"outputFields": ["id", "uid"]
|
||||||
|
}
|
||||||
|
time.sleep(1)
|
||||||
|
rsp = self.vector_client.vector_query(payload)
|
||||||
|
assert len(rsp["data"]) == num_before_delete
|
2
tests/restful_client/utils/constant.py
Normal file
2
tests/restful_client/utils/constant.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
|
||||||
|
MAX_SUM_OFFSET_AND_LIMIT = 16384
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from loguru import logger as loguru_logger
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from config.log_config import log_config
|
from config.log_config import log_config
|
||||||
@ -54,4 +55,6 @@ log_debug = log_config.log_debug
|
|||||||
log_info = log_config.log_info
|
log_info = log_config.log_info
|
||||||
log_err = log_config.log_err
|
log_err = log_config.log_err
|
||||||
log_worker = log_config.log_worker
|
log_worker = log_config.log_worker
|
||||||
test_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
|
self_defined_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
|
||||||
|
loguru_log = loguru_logger
|
||||||
|
test_log = self_defined_log
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
import functools
|
|
||||||
from utils.util_log import test_log as log
|
|
||||||
|
|
||||||
DEFAULT_FMT = '[{start_time}] [{elapsed:0.8f}s] {collection_name} {func_name} -> {res!r}'
|
|
||||||
|
|
||||||
|
|
||||||
def trace(fmt=DEFAULT_FMT, prefix='test', flag=True):
|
|
||||||
def decorate(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
def inner_wrapper(*args, **kwargs):
|
|
||||||
# args[0] is an instance of ApiCollectionWrapper class
|
|
||||||
flag = args[0].active_trace
|
|
||||||
if flag:
|
|
||||||
start_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
res, result = func(*args, **kwargs)
|
|
||||||
elapsed = time.perf_counter() - t0
|
|
||||||
end_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
func_name = func.__name__
|
|
||||||
collection_name = args[0].collection.name
|
|
||||||
# arg_lst = [repr(arg) for arg in args[1:]][:100]
|
|
||||||
# arg_lst.extend(f'{k}={v!r}' for k, v in kwargs.items())
|
|
||||||
# arg_str = ', '.join(arg_lst)[:200]
|
|
||||||
|
|
||||||
log_str = f"[{prefix}]" + fmt.format(**locals())
|
|
||||||
# TODO: add report function in this place, like uploading to influxdb
|
|
||||||
# it is better a async way to do this, in case of blocking the request processing
|
|
||||||
log.info(log_str)
|
|
||||||
return res, result
|
|
||||||
else:
|
|
||||||
res, result = func(*args, **kwargs)
|
|
||||||
return res, result
|
|
||||||
|
|
||||||
return inner_wrapper
|
|
||||||
|
|
||||||
return decorate
|
|
155
tests/restful_client/utils/utils.py
Normal file
155
tests/restful_client/utils/utils.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import random
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
from faker import Faker
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import preprocessing
|
||||||
|
import requests
|
||||||
|
from loguru import logger
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
|
||||||
|
def random_string(length=8):
|
||||||
|
letters = string.ascii_letters
|
||||||
|
return ''.join(random.choice(letters) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
def gen_collection_name(prefix="test_collection", length=8):
|
||||||
|
name = f'{prefix}_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + random_string(length=length)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def admin_password():
|
||||||
|
return "Milvus"
|
||||||
|
|
||||||
|
|
||||||
|
def invalid_cluster_name():
|
||||||
|
res = [
|
||||||
|
"demo" * 100,
|
||||||
|
"demo" + "!",
|
||||||
|
"demo" + "@",
|
||||||
|
]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def wait_cluster_be_ready(cluster_id, client, timeout=120):
|
||||||
|
t0 = time.time()
|
||||||
|
while True and time.time() - t0 < timeout:
|
||||||
|
rsp = client.cluster_describe(cluster_id)
|
||||||
|
if rsp['code'] == 200:
|
||||||
|
if rsp['data']['status'] == "RUNNING":
|
||||||
|
return time.time() - t0
|
||||||
|
time.sleep(1)
|
||||||
|
logger.debug("wait cluster to be ready, cost time: %s" % (time.time() - t0))
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def gen_data_by_type(field):
|
||||||
|
data_type = field["type"]
|
||||||
|
if data_type == "bool":
|
||||||
|
return random.choice([True, False])
|
||||||
|
if data_type == "int8":
|
||||||
|
return random.randint(-128, 127)
|
||||||
|
if data_type == "int16":
|
||||||
|
return random.randint(-32768, 32767)
|
||||||
|
if data_type == "int32":
|
||||||
|
return random.randint(-2147483648, 2147483647)
|
||||||
|
if data_type == "int64":
|
||||||
|
return random.randint(-9223372036854775808, 9223372036854775807)
|
||||||
|
if data_type == "float32":
|
||||||
|
return np.float64(random.random()) # Object of type float32 is not JSON serializable, so set it as float64
|
||||||
|
if data_type == "float64":
|
||||||
|
return np.float64(random.random())
|
||||||
|
if "varchar" in data_type:
|
||||||
|
length = int(data_type.split("(")[1].split(")")[0])
|
||||||
|
return "".join([chr(random.randint(97, 122)) for _ in range(length)])
|
||||||
|
if "floatVector" in data_type:
|
||||||
|
dim = int(data_type.split("(")[1].split(")")[0])
|
||||||
|
return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_by_fields(fields, nb):
|
||||||
|
# logger.info(f"fields: {fields}")
|
||||||
|
fields_not_auto_id = []
|
||||||
|
for field in fields:
|
||||||
|
if not field.get("autoId", False):
|
||||||
|
fields_not_auto_id.append(field)
|
||||||
|
# logger.info(f"fields_not_auto_id: {fields_not_auto_id}")
|
||||||
|
data = []
|
||||||
|
for i in range(nb):
|
||||||
|
tmp = {}
|
||||||
|
for field in fields_not_auto_id:
|
||||||
|
tmp[field["name"]] = gen_data_by_type(field)
|
||||||
|
data.append(tmp)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_json_data(uid=None):
|
||||||
|
# gen random dict data
|
||||||
|
if uid is None:
|
||||||
|
uid = 0
|
||||||
|
data = {"uid": uid, "name": fake.name(), "address": fake.address(), "text": fake.text(), "email": fake.email(),
|
||||||
|
"phone_number": fake.phone_number(),
|
||||||
|
"json": {
|
||||||
|
"name": fake.name(),
|
||||||
|
"address": fake.address()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i in range(random.randint(1, 10)):
|
||||||
|
data["key" + str(random.randint(1, 100_000))] = "value" + str(random.randint(1, 100_000))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_by_payload(payload, nb=100):
|
||||||
|
dim = payload.get("dimension", 128)
|
||||||
|
vector_field = payload.get("vectorField", "vector")
|
||||||
|
data = []
|
||||||
|
if nb == 1:
|
||||||
|
data = [{
|
||||||
|
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
|
||||||
|
**get_random_json_data()
|
||||||
|
|
||||||
|
}]
|
||||||
|
else:
|
||||||
|
for i in range(nb):
|
||||||
|
data.append({
|
||||||
|
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
|
||||||
|
**get_random_json_data(uid=i)
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_common_fields_by_data(data, exclude_fields=None):
|
||||||
|
fields = set()
|
||||||
|
if isinstance(data, dict):
|
||||||
|
data = [data]
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise Exception("data must be list or dict")
|
||||||
|
common_fields = set(data[0].keys())
|
||||||
|
for d in data:
|
||||||
|
keys = set(d.keys())
|
||||||
|
common_fields = common_fields.intersection(keys)
|
||||||
|
if exclude_fields is not None:
|
||||||
|
exclude_fields = set(exclude_fields)
|
||||||
|
common_fields = common_fields.difference(exclude_fields)
|
||||||
|
return list(common_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_fields_by_data(data, exclude_fields=None):
|
||||||
|
fields = set()
|
||||||
|
for d in data:
|
||||||
|
keys = list(d.keys())
|
||||||
|
fields.union(keys)
|
||||||
|
if exclude_fields is not None:
|
||||||
|
exclude_fields = set(exclude_fields)
|
||||||
|
fields = fields.difference(exclude_fields)
|
||||||
|
return list(fields)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user