mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 04:19:18 +08:00
feat(python): complete sdk 0.0.1
Former-commit-id: 713b3a629fcf86291d1e35c1d1b038b6fc599cb8
This commit is contained in:
parent
964d3f376e
commit
508131e4bc
@ -1,70 +1,10 @@
|
||||
from enum import IntEnum
|
||||
from .Exceptions import ConnectParamMissingError
|
||||
|
||||
|
||||
class AbstactIndexType(object):
    """String codes identifying the available vector index types."""

    RAW = '1'
    IVFFLAT = '2'
|
||||
|
||||
|
||||
class AbstractColumnType(object):
    """Integer codes identifying the type of a table column."""

    INVALID = 1
    INT8 = 2
    INT16 = 3
    INT32 = 4
    INT64 = 5
    FLOAT32 = 6
    FLOAT64 = 7
    DATE = 8
    VECTOR = 9
|
||||
|
||||
|
||||
class Column(object):
    """
    Table column description

    :type  type: ColumnType
    :param type: (Required) type of the column

    :type  name: str
    :param name: (Required) name of the column

    """

    def __init__(self, name=None, type=AbstractColumnType.INVALID):
        # ``type`` shadows the builtin, but callers pass it by keyword
        # (``type=...``), so the parameter name is part of the interface.
        self.type = type
        self.name = name
|
||||
|
||||
|
||||
class VectorColumn(Column):
    """
    Table vector column description

    :type  dimension: int, int64
    :param dimension: (Required) vector dimension

    :type  index_type: string IndexType
    :param index_type: (Required) IndexType

    :type  store_raw_vector: bool
    :param store_raw_vector: (Required) Is vector self stored in the table

    `Column`:
        :type  name: str
        :param name: (Required) Name of the column

        :type  type: ColumnType
        :param type: (Required) Default type is ColumnType.VECTOR, can't change

    """

    def __init__(self, name,
                 dimension=0,
                 index_type=None,
                 store_raw_vector=False,
                 type=None):
        self.dimension = dimension
        self.index_type = index_type
        self.store_raw_vector = store_raw_vector
        # NOTE(review): ``type`` is forwarded exactly as given, so it stays
        # ``None`` unless the caller supplies a value; the "default type is
        # ColumnType.VECTOR" statement above is not enforced here -- confirm
        # which of the two is intended.
        super(VectorColumn, self).__init__(name, type=type)
|
||||
class IndexType(IntEnum):
    """Integer codes for the index type of a table: 0-invalid, 1-idmap, 2-ivflat."""

    # The spelling ``INVALIDE`` is kept as-is: other code refers to the
    # member by this exact name (e.g. as a default argument value).
    INVALIDE = 0
    IDMAP = 1
    IVFLAT = 2
|
||||
|
||||
|
||||
class TableSchema(object):
|
||||
@ -74,30 +14,26 @@ class TableSchema(object):
|
||||
:type table_name: str
|
||||
:param table_name: (Required) name of table
|
||||
|
||||
:type vector_columns: list[VectorColumn]
|
||||
:param vector_columns: (Required) a list of VectorColumns,
|
||||
:type index_type: IndexType
|
||||
:param index_type: (Optional) index type, default = 0
|
||||
|
||||
Stores different types of vectors
|
||||
`IndexType`: 0-invalid, 1-idmap, 2-ivflat
|
||||
|
||||
:type attribute_columns: list[Column]
|
||||
:param attribute_columns: (Optional) Columns description
|
||||
:type dimension: int64
|
||||
:param dimension: (Required) dimension of vector
|
||||
|
||||
List of `Columns` whose type isn't VECTOR
|
||||
|
||||
:type partition_column_names: list[str]
|
||||
:param partition_column_names: (Optional) Partition column name
|
||||
|
||||
`Partition columns` are `attribute columns`, the number of
|
||||
partition columns may be less than or equal to attribute columns,
|
||||
this param only stores `column name`
|
||||
:type store_raw_vector: bool
|
||||
:param store_raw_vector: (Optional) default = False
|
||||
|
||||
"""
|
||||
def __init__(self, table_name, vector_columns,
|
||||
attribute_columns, partition_column_names, **kwargs):
|
||||
def __init__(self, table_name,
|
||||
dimension=0,
|
||||
index_type=IndexType.INVALIDE,
|
||||
store_raw_vector=False):
|
||||
self.table_name = table_name
|
||||
self.vector_columns = vector_columns
|
||||
self.attribute_columns = attribute_columns
|
||||
self.partition_column_names = partition_column_names
|
||||
self.index_type = index_type
|
||||
self.dimension = dimension
|
||||
self.store_raw_vector = store_raw_vector
|
||||
|
||||
|
||||
class Range(object):
|
||||
@ -105,10 +41,10 @@ class Range(object):
|
||||
Range information
|
||||
|
||||
:type start: str
|
||||
:param start: (Required) Range start value
|
||||
:param start: Range start value
|
||||
|
||||
:type end: str
|
||||
:param end: (Required) Range end value
|
||||
:param end: Range end value
|
||||
|
||||
"""
|
||||
def __init__(self, start, end):
|
||||
@ -116,97 +52,37 @@ class Range(object):
|
||||
self.end = end
|
||||
|
||||
|
||||
class CreateTablePartitionParam(object):
    """
    Create table partition parameters

    :type  table_name: str
    :param table_name: (Required) Table name,
        VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition

    :type  partition_name: str
    :param partition_name: (Required) partition name, created partition name

    :type  column_name_to_range: dict{str : Range}
    :param column_name_to_range: (Required) Column name to PartitionRange dictionary
    """
    # TODO Iterable

    def __init__(self, table_name, partition_name, column_name_to_range):
        # Plain value holder: the arguments are stored unchanged, without
        # validation or copying.
        self.table_name = table_name
        self.partition_name = partition_name
        self.column_name_to_range = column_name_to_range
|
||||
|
||||
|
||||
class DeleteTablePartitionParam(object):
    """
    Delete table partition parameters

    :type  table_name: str
    :param table_name: (Required) Table name

    :type  partition_names: iterable, str
    :param partition_names: (Required) Partition name array

    """
    # TODO Iterable

    def __init__(self, table_name, partition_names):
        # Plain value holder: the arguments are stored unchanged, without
        # validation or copying.
        self.table_name = table_name
        self.partition_names = partition_names
|
||||
|
||||
|
||||
class RowRecord(object):
|
||||
"""
|
||||
Record inserted
|
||||
|
||||
:type column_name_to_vector: dict{str : list[float]}
|
||||
:param column_name_to_vector: (Required) Column name to vector map
|
||||
|
||||
:type column_name_to_attribute: dict{str: str}
|
||||
:param column_name_to_attribute: (Optional) Other attribute columns
|
||||
"""
|
||||
def __init__(self, column_name_to_vector, column_name_to_attribute):
|
||||
self.column_name_to_vector = column_name_to_vector
|
||||
self.column_name_to_attribute = column_name_to_attribute
|
||||
|
||||
|
||||
class QueryRecord(object):
|
||||
"""
|
||||
Query record
|
||||
|
||||
:type column_name_to_vector: (Required) dict{str : list[float]}
|
||||
:param column_name_to_vector: Query vectors, column name to vector map
|
||||
|
||||
:type selected_columns: list[str]
|
||||
:param selected_columns: (Optional) Output column array
|
||||
|
||||
:type name_to_partition_ranges: dict{str : list[Range]}
|
||||
:param name_to_partition_ranges: (Optional) Range used to select partitions
|
||||
:type vector_data: binary str
|
||||
:param vector_data: (Required) a vector
|
||||
|
||||
"""
|
||||
def __init__(self, column_name_to_vector, selected_columns, name_to_partition_ranges):
|
||||
self.column_name_to_vector = column_name_to_vector
|
||||
self.selected_columns = selected_columns
|
||||
self.name_to_partition_ranges = name_to_partition_ranges
|
||||
def __init__(self, vector_data):
|
||||
self.vector_data = vector_data
|
||||
|
||||
|
||||
class QueryResult(object):
|
||||
"""
|
||||
Query result
|
||||
|
||||
:type id: int
|
||||
:param id: Output result
|
||||
:type id: int64
|
||||
:param id: id of the vector
|
||||
|
||||
:type score: float
|
||||
:param score: Vector similarity 0 <= score <= 100
|
||||
|
||||
:type column_name_to_attribute: dict{str : str}
|
||||
:param column_name_to_attribute: Other columns
|
||||
|
||||
"""
|
||||
def __init__(self, id, score, column_name_to_attribute):
|
||||
def __init__(self, id, score):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.column_name_to_value = column_name_to_attribute
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, value)
|
||||
for key, value in self.__dict__.items()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
|
||||
class TopKQueryResult(object):
|
||||
@ -220,6 +96,12 @@ class TopKQueryResult(object):
|
||||
def __init__(self, query_results):
|
||||
self.query_results = query_results
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, value)
|
||||
for key, value in self.__dict__.items()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
|
||||
|
||||
def _abstract():
|
||||
raise NotImplementedError('You need to override this function')
|
||||
@ -232,114 +114,71 @@ class ConnectIntf(object):
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create():
|
||||
"""Create a connection instance and return it
|
||||
should be implemented
|
||||
|
||||
:return connection: Connection
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
@staticmethod
|
||||
def destroy(connection):
|
||||
"""Destroy the connection instance
|
||||
should be implemented
|
||||
|
||||
:type connection: Connection
|
||||
:param connection: The connection instance to be destroyed
|
||||
|
||||
:return bool, return True if destroy is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def connect(self, param=None, uri=None):
|
||||
def connect(self, host=None, port=None, uri=None):
|
||||
"""
|
||||
Connect method should be called before any operations
|
||||
Server will be connected after connect return OK
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type param: ConnectParam
|
||||
:param param: ConnectParam
|
||||
:type host: str
|
||||
:param host: host
|
||||
|
||||
:type port: str
|
||||
:param port: port
|
||||
|
||||
:type uri: str
|
||||
:param uri: uri param
|
||||
:param uri: (Optional) uri
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
if (not param and not uri) or (param and uri):
|
||||
raise ConnectParamMissingError('You need to parse exact one param')
|
||||
_abstract()
|
||||
|
||||
def connected(self):
|
||||
"""
|
||||
connected, connection status
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnect, server will be disconnected after disconnect return OK
|
||||
should be implemented
|
||||
Disconnect, server will be disconnected after disconnect return SUCCESS
|
||||
Should be implemented
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def create_table(self, param):
|
||||
"""
|
||||
Create table
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type param: TableSchema
|
||||
:param param: provide table information to be created
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def delete_table(self, table_name):
|
||||
"""
|
||||
Delete table
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table_name of the deleting table
|
||||
|
||||
:return: Status, indicate if connect is successful
|
||||
:return Status, indicate if connect is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def create_table_partition(self, param):
|
||||
"""
|
||||
Create table partition
|
||||
should be implemented
|
||||
|
||||
:type param: CreateTablePartitionParam
|
||||
:param param: provide partition information
|
||||
|
||||
:return: Status, indicate if table partition is created successfully
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def delete_table_partition(self, param):
|
||||
"""
|
||||
Delete table partition
|
||||
should be implemented
|
||||
|
||||
:type param: DeleteTablePartitionParam
|
||||
:param param: provide partition information to be deleted
|
||||
:return: Status, indicate if partition is deleted successfully
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def add_vector(self, table_name, records):
|
||||
def add_vectors(self, table_name, records):
|
||||
"""
|
||||
Add vectors to table
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been inserted
|
||||
@ -347,27 +186,31 @@ class ConnectIntf(object):
|
||||
:type records: list[RowRecord]
|
||||
:param records: list of vectors been inserted
|
||||
|
||||
:returns:
|
||||
:returns
|
||||
Status : indicate if vectors inserted successfully
|
||||
ids :list of id, after inserted every vector is given a id
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def search_vector(self, table_name, query_records, top_k):
|
||||
def search_vectors(self, table_name, query_records, query_ranges, top_k):
|
||||
"""
|
||||
Query vectors in a table
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: table name been queried
|
||||
|
||||
:type query_records: list[QueryRecord]
|
||||
:type query_records: list[RowRecord]
|
||||
:param query_records: all vectors going to be queried
|
||||
|
||||
:type query_ranges: list[Range]
|
||||
:param query_ranges: Optional ranges for conditional search.
|
||||
If not specified, search whole table
|
||||
|
||||
:type top_k: int
|
||||
:param top_k: how many similar vectors will be searched
|
||||
|
||||
:returns:
|
||||
:returns
|
||||
Status: indicate if query is successful
|
||||
query_results: list[TopKQueryResult]
|
||||
"""
|
||||
@ -376,23 +219,37 @@ class ConnectIntf(object):
|
||||
def describe_table(self, table_name):
|
||||
"""
|
||||
Show table information
|
||||
should be implemented
|
||||
Should be implemented
|
||||
|
||||
:type table_name: str
|
||||
:param table_name: which table to be shown
|
||||
|
||||
:returns:
|
||||
:returns
|
||||
Status: indicate if query is successful
|
||||
table_schema: TableSchema, given when operation is successful
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def get_table_row_count(self, table_name):
|
||||
"""
|
||||
Get table row count
|
||||
Should be implemented
|
||||
|
||||
:type table_name, str
|
||||
:param table_name, target table name.
|
||||
|
||||
:returns
|
||||
Status: indicate if operation is successful
|
||||
count: int, table row count
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def show_tables(self):
|
||||
"""
|
||||
Show all tables in database
|
||||
should be implemented
|
||||
|
||||
:return:
|
||||
:return
|
||||
Status: indicate if this operation is successful
|
||||
tables: list[str], list of table names
|
||||
"""
|
||||
@ -403,31 +260,28 @@ class ConnectIntf(object):
|
||||
Provide client version
|
||||
should be implemented
|
||||
|
||||
:return: Client version
|
||||
:return: str, client version
|
||||
"""
|
||||
_abstract()
|
||||
pass
|
||||
|
||||
def server_version(self):
|
||||
"""
|
||||
Provide server version
|
||||
should be implemented
|
||||
|
||||
:return: Server version
|
||||
:return: str, server version
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def server_status(self, cmd):
|
||||
"""
|
||||
Provide server status
|
||||
should be implemented
|
||||
# TODO What is cmd
|
||||
:type cmd
|
||||
:param cmd
|
||||
:type cmd, str
|
||||
|
||||
:return: Server status
|
||||
:return: str, server status
|
||||
"""
|
||||
_abstract()
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
@ -9,21 +9,20 @@ from thrift.Thrift import TException, TApplicationException, TType
|
||||
from megasearch.thrift import MegasearchService
|
||||
from megasearch.thrift import ttypes
|
||||
from client.Abstract import (
|
||||
ConnectIntf, TableSchema,
|
||||
AbstactIndexType, AbstractColumnType,
|
||||
Column,
|
||||
VectorColumn, Range,
|
||||
CreateTablePartitionParam,
|
||||
DeleteTablePartitionParam,
|
||||
RowRecord, QueryRecord,
|
||||
QueryResult, TopKQueryResult
|
||||
ConnectIntf,
|
||||
TableSchema,
|
||||
IndexType,
|
||||
Range,
|
||||
RowRecord,
|
||||
QueryResult,
|
||||
TopKQueryResult
|
||||
)
|
||||
|
||||
from client.Status import Status
|
||||
from client.Exceptions import (
|
||||
RepeatingConnectError, ConnectParamMissingError,
|
||||
RepeatingConnectError,
|
||||
DisconnectNotConnectedClientError,
|
||||
ParamError, NotConnectError
|
||||
NotConnectError
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
@ -32,125 +31,35 @@ __VERSION__ = '0.0.1'
|
||||
__NAME__ = 'Thrift_Client'
|
||||
|
||||
|
||||
class IndexType(AbstactIndexType):
    """Index type codes used by the Thrift client."""
    # TODO thrift in IndexType

    RAW = '1'
    IVFFLAT = '2'
|
||||
|
||||
|
||||
class ColumnType(AbstractColumnType):
    """Column type codes used by the Thrift client."""

    # Plain integer codes.
    FLOAT32 = 6
    FLOAT64 = 7
    DATE = 8

    # Codes taken from the thrift ``TType`` constants.
    INVALID = TType.STOP
    INT8 = TType.I08
    INT16 = TType.I16
    INT32 = TType.I32
    INT64 = TType.I64
    VECTOR = TType.LIST
|
||||
|
||||
# TODO Required and Optional
|
||||
# TODO Examples
|
||||
# TODO ORM
|
||||
class Prepare(object):
|
||||
|
||||
@classmethod
|
||||
def column(cls, name, type):
|
||||
"""
|
||||
Table column param
|
||||
:param type: (Required) ColumnType, type of the column
|
||||
:param name: (Required) str, name of the column
|
||||
|
||||
:return Column
|
||||
"""
|
||||
temp_column = Column(name=name, type=type)
|
||||
return ttypes.Column(name=temp_column.name, type=temp_column.type)
|
||||
|
||||
@classmethod
|
||||
def vector_column(cls, name, dimension,
|
||||
index_type=IndexType.RAW,
|
||||
store_raw_vector=False):
|
||||
"""
|
||||
Table vector column description
|
||||
|
||||
:param dimension: (Required) int64, vector dimension
|
||||
:param index_type: (Required) IndexType
|
||||
:param store_raw_vector: (Required) Bool
|
||||
|
||||
`Column`:
|
||||
:param name: (Required) Name of the column
|
||||
:param type: (Required) Default type is ColumnType.VECTOR, can't change
|
||||
|
||||
:return VectorColumn
|
||||
def table_schema(cls,
|
||||
table_name, *,
|
||||
dimension,
|
||||
index_type,
|
||||
store_raw_vector):
|
||||
"""
|
||||
|
||||
temp = VectorColumn(name=name, dimension=dimension,
|
||||
index_type=index_type, store_raw_vector=store_raw_vector)
|
||||
base = ttypes.Column(name=temp.name, type=ColumnType.VECTOR)
|
||||
return ttypes.VectorColumn(base=base, dimension=temp.dimension,
|
||||
store_raw_vector=temp.store_raw_vector,
|
||||
index_type=temp.index_type)
|
||||
|
||||
# Without IndexType
|
||||
# temp = VectorColumn(name=name, dimension=dimension,
|
||||
# store_raw_vector=store_raw_vector)
|
||||
# return ttypes.VectorColumn(base=base, dimension=temp.dimension,
|
||||
# store_raw_vector=temp.store_raw_vector)
|
||||
|
||||
@classmethod
|
||||
def table_schema(cls, table_name,
|
||||
vector_columns,
|
||||
attribute_columns,
|
||||
partition_column_names):
|
||||
"""
|
||||
|
||||
:param table_name: (Required) Name of the table
|
||||
:param vector_columns: (Required) List of VectorColumns
|
||||
|
||||
`VectorColumn`:
|
||||
- dimension: int, default = 0
|
||||
Dimension of the vector, different vector_columns'
|
||||
dimension may vary
|
||||
- index_type: (optional) IndexType, default=IndexType.RAW
|
||||
Vector's index type
|
||||
- store_raw_vector : (optional) bool, default=False
|
||||
- name: str
|
||||
Name of the column
|
||||
- type: ColumnType, default=ColumnType.VECTOR, can't change
|
||||
|
||||
:param attribute_columns: (Optional) List of Columns. Attribute columns are Columns,
|
||||
whose types aren't ColumnType.VECTOR
|
||||
|
||||
`Column`:
|
||||
- name: str
|
||||
- type: ColumnType, default=ColumnType.INVALID
|
||||
|
||||
:param partition_column_names: (Optional) List of str.
|
||||
|
||||
Partition columns name
|
||||
indicates which attribute columns is used for partition, can
|
||||
have lots of partition columns as long as:
|
||||
-> No. partition_column_names <= No. attribute_columns
|
||||
-> partition_column_names IN attribute_column_names
|
||||
:param table_name: str, (Required) name of table
|
||||
:param index_type: IndexType, (Required) index type, default = IndexType.INVALID
|
||||
:param dimension: int64, (Optional) dimension of the table
|
||||
:param store_raw_vector: bool, (Optional) default = False
|
||||
|
||||
:return: TableSchema
|
||||
"""
|
||||
temp = TableSchema(table_name,vector_columns,
|
||||
attribute_columns,
|
||||
partition_column_names)
|
||||
temp = TableSchema(table_name,dimension, index_type, store_raw_vector)
|
||||
|
||||
return ttypes.TableSchema(table_name=temp.table_name,
|
||||
vector_column_array=temp.vector_columns,
|
||||
attribute_column_array=temp.attribute_columns,
|
||||
partition_column_name_array=temp.partition_column_names)
|
||||
dimension=dimension,
|
||||
index_type=index_type,
|
||||
store_raw_vector=store_raw_vector)
|
||||
|
||||
@classmethod
|
||||
def range(cls, start, end):
|
||||
"""
|
||||
:param start: (Required) Partition range start value
|
||||
:param end: (Required) Partition range end value
|
||||
:param start: str, (Required) range start
|
||||
:param end: str (Required) range end
|
||||
|
||||
:return Range
|
||||
"""
|
||||
@ -158,142 +67,66 @@ class Prepare(object):
|
||||
return ttypes.Range(start_value=temp.start, end_value=temp.end)
|
||||
|
||||
@classmethod
|
||||
def create_table_partition_param(cls,
|
||||
table_name,
|
||||
partition_name,
|
||||
column_name_to_range):
|
||||
def row_record(cls, vector_data):
|
||||
"""
|
||||
Create table partition parameters
|
||||
:param table_name: (Required) str, Table name,
|
||||
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
|
||||
:param partition_name: (Required) str partition name, created partition name
|
||||
:param column_name_to_range: (Required) dict, column name to partition range dictionary
|
||||
Record inserted
|
||||
|
||||
:param vector_data: float binary str, (Required) a binary str
|
||||
|
||||
:return CreateTablePartitionParam
|
||||
"""
|
||||
temp = CreateTablePartitionParam(table_name=table_name,
|
||||
partition_name=partition_name,
|
||||
column_name_to_range=column_name_to_range)
|
||||
return ttypes.CreateTablePartitionParam(table_name=temp.table_name,
|
||||
partition_name=temp.partition_name,
|
||||
range_map=temp.column_name_to_range)
|
||||
|
||||
@classmethod
|
||||
def delete_table_partition_param(cls, table_name, partition_names):
|
||||
"""
|
||||
Delete table partition parameters
|
||||
:param table_name: (Required) Table name
|
||||
:param partition_names: (Required) List of partition names
|
||||
|
||||
:return DeleteTablePartitionParam
|
||||
"""
|
||||
temp = DeleteTablePartitionParam(table_name=table_name,
|
||||
partition_names=partition_names)
|
||||
return ttypes.DeleteTablePartitionParam(table_name=table_name,
|
||||
partition_name_array=partition_names)
|
||||
|
||||
@classmethod
|
||||
def row_record(cls, column_name_to_vector, column_name_to_attribute):
|
||||
"""
|
||||
:param column_name_to_vector: (Required) dict{str : list[float]}
|
||||
Column name to vector map
|
||||
|
||||
:param column_name_to_attribute: (Optional) dict{str: str}
|
||||
Other attribute columns
|
||||
"""
|
||||
temp = RowRecord(column_name_to_vector=column_name_to_vector,
|
||||
column_name_to_attribute=column_name_to_attribute)
|
||||
return ttypes.RowRecord(vector_map=temp.column_name_to_vector,
|
||||
attribute_map=temp.column_name_to_attribute)
|
||||
|
||||
@classmethod
|
||||
def query_record(cls, column_name_to_vector,
|
||||
selected_columns, name_to_partition_ranges):
|
||||
"""
|
||||
:param column_name_to_vector: (Required) dict{str : list[float]}
|
||||
Query vectors, column name to vector map
|
||||
|
||||
:param selected_columns: (Optional) list[str_column_name]
|
||||
List of Output columns
|
||||
|
||||
:param name_to_partition_ranges: (Optional) dict{str : list[Range]}
|
||||
Partition Range used to search
|
||||
|
||||
`Range`:
|
||||
:param start: Partition range start value
|
||||
:param end: Partition range end value
|
||||
|
||||
:return QueryRecord
|
||||
"""
|
||||
temp = QueryRecord(column_name_to_vector=column_name_to_vector,
|
||||
selected_columns=selected_columns,
|
||||
name_to_partition_ranges=name_to_partition_ranges)
|
||||
return ttypes.QueryRecord(vector_map=temp.column_name_to_vector,
|
||||
selected_column_array=temp.selected_columns,
|
||||
partition_filter_column_map=name_to_partition_ranges)
|
||||
temp = RowRecord(vector_data)
|
||||
return ttypes.RowRecord(vector_data=temp.vector_data)
|
||||
|
||||
|
||||
class MegaSearch(ConnectIntf):
|
||||
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
self.client = None
|
||||
self.status = None
|
||||
self._transport = None
|
||||
self._client = None
|
||||
|
||||
def __repr__(self):
|
||||
return '{}'.format(self.status)
|
||||
|
||||
@staticmethod
|
||||
def create():
|
||||
# TODO in python, maybe this method is useless
|
||||
return MegaSearch()
|
||||
|
||||
@staticmethod
|
||||
def destroy(connection):
|
||||
"""Destroy the connection instance"""
|
||||
# TODO in python, maybe this method is useless
|
||||
|
||||
pass
|
||||
|
||||
def connect(self, host='localhost', port='9090', uri=None):
|
||||
# TODO URI
|
||||
if self.status and self.status == Status(message='Connected'):
|
||||
if self.status and self.status == Status.SUCCESS:
|
||||
raise RepeatingConnectError("You have already connected!")
|
||||
|
||||
transport = TSocket.TSocket(host=host, port=port)
|
||||
self.transport = TTransport.TBufferedTransport(transport)
|
||||
self._transport = TTransport.TBufferedTransport(transport)
|
||||
protocol = TBinaryProtocol.TBinaryProtocol(transport)
|
||||
self.client = MegasearchService.Client(protocol)
|
||||
self._client = MegasearchService.Client(protocol)
|
||||
|
||||
try:
|
||||
transport.open()
|
||||
self.status = Status(Status.OK, 'Connected')
|
||||
self.status = Status(Status.SUCCESS, 'Connected')
|
||||
LOGGER.info('Connected!')
|
||||
|
||||
except (TTransport.TTransportException, TException) as e:
|
||||
self.status = Status(Status.INVALID, message=str(e))
|
||||
self.status = Status(Status.CONNECT_FAILED, message=str(e))
|
||||
LOGGER.error('logger.error: {}'.format(self.status))
|
||||
finally:
|
||||
return self.status
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.status == Status()
|
||||
return self.status == Status.SUCCESS
|
||||
|
||||
def disconnect(self):
|
||||
|
||||
if not self.transport:
|
||||
if not self._transport:
|
||||
raise DisconnectNotConnectedClientError('Error')
|
||||
|
||||
try:
|
||||
|
||||
self.transport.close()
|
||||
self._transport.close()
|
||||
LOGGER.info('Client Disconnected!')
|
||||
self.status = None
|
||||
|
||||
except TException as e:
|
||||
return Status(Status.INVALID, str(e))
|
||||
return Status(Status.OK, 'Disconnected')
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(Status.SUCCESS, 'Disconnected')
|
||||
|
||||
def create_table(self, param):
|
||||
"""Create table
|
||||
@ -304,15 +137,14 @@ class MegaSearch(ConnectIntf):
|
||||
|
||||
:return: Status, indicate if operation is successful
|
||||
"""
|
||||
if not self.client:
|
||||
if not self._client:
|
||||
raise NotConnectError('Please Connect to the server first!')
|
||||
|
||||
try:
|
||||
LOGGER.error(param)
|
||||
self.client.CreateTable(param)
|
||||
self._client.CreateTable(param)
|
||||
except (TApplicationException, ) as e:
|
||||
LOGGER.error('Unable to create table')
|
||||
return Status(Status.INVALID, str(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(message='Table {} created!'.format(param.table_name))
|
||||
|
||||
def delete_table(self, table_name):
|
||||
@ -323,48 +155,13 @@ class MegaSearch(ConnectIntf):
|
||||
:return: Status, indicate if operation is successful
|
||||
"""
|
||||
try:
|
||||
self.client.DeleteTable(table_name)
|
||||
self._client.DeleteTable(table_name)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('Unable to delete table {}'.format(table_name))
|
||||
return Status(Status.INVALID, str(e))
|
||||
return Status(Status.PERMISSION_DENIED, str(e))
|
||||
return Status(message='Table {} deleted!'.format(table_name))
|
||||
|
||||
def create_table_partition(self, param):
|
||||
"""
|
||||
Create table partition
|
||||
|
||||
:type param: CreateTablePartitionParam, provide partition information
|
||||
|
||||
`Please use Prepare.create_table_partition_param generate param`
|
||||
|
||||
:return: Status, indicate if table partition is created successfully
|
||||
"""
|
||||
try:
|
||||
self.client.CreateTablePartition(param)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e))
|
||||
return Status(message='Table partition created successfully!')
|
||||
|
||||
def delete_table_partition(self, param):
|
||||
"""
|
||||
Delete table partition
|
||||
|
||||
:type param: DeleteTablePartitionParam
|
||||
:param param: provide partition information to be deleted
|
||||
|
||||
`Please use Prepare.delete_table_partition_param generate param`
|
||||
|
||||
:return: Status, indicate if partition is deleted successfully
|
||||
"""
|
||||
try:
|
||||
self.client.DeleteTablePartition(param)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e))
|
||||
return Status(message='Table partition deleted successfully!')
|
||||
|
||||
def add_vector(self, table_name, records):
|
||||
def add_vectors(self, table_name, records):
|
||||
"""
|
||||
Add vectors to table
|
||||
|
||||
@ -378,13 +175,13 @@ class MegaSearch(ConnectIntf):
|
||||
ids :list of id, after inserted every vector is given a id
|
||||
"""
|
||||
try:
|
||||
ids = self.client.AddVector(table_name=table_name, record_array=records)
|
||||
ids = self._client.AddVector(table_name=table_name, record_array=records)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e)), None
|
||||
return Status(message='Vector added successfully!'), ids
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Vectors added successfully!'), ids
|
||||
|
||||
def search_vector(self, table_name, query_records, top_k):
|
||||
def search_vectors(self, table_name, top_k, query_records, query_ranges=None):
|
||||
"""
|
||||
Query vectors in a table
|
||||
|
||||
@ -394,20 +191,29 @@ class MegaSearch(ConnectIntf):
|
||||
`Please use Prepare.query_record generate QueryRecord`
|
||||
|
||||
:param top_k: int, how many similar vectors will be searched
|
||||
:param query_ranges, (Optional) list[Range], search range
|
||||
|
||||
:returns:
|
||||
Status: indicate if query is successful
|
||||
query_results: list[TopKQueryResult], return when operation is successful
|
||||
res: list[TopKQueryResult], return when operation is successful
|
||||
"""
|
||||
# TODO topk_query_results
|
||||
res = []
|
||||
try:
|
||||
topk_query_results = self.client.SearchVector(
|
||||
table_name=table_name, query_record_array=query_records, topk=top_k)
|
||||
top_k_query_results = self._client.SearchVector(
|
||||
table_name=table_name,
|
||||
query_record_array=query_records,
|
||||
query_range_array=query_ranges,
|
||||
topk=top_k)
|
||||
|
||||
if top_k_query_results:
|
||||
for top_k in top_k_query_results:
|
||||
res.append(TopKQueryResult([QueryResult(qr.id, qr.score)
|
||||
for qr in top_k.query_result_arrays]))
|
||||
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e)), None
|
||||
return Status(message='Success!'), topk_query_results
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), res
|
||||
|
||||
def describe_table(self, table_name):
|
||||
"""
|
||||
@ -420,12 +226,14 @@ class MegaSearch(ConnectIntf):
|
||||
table_schema: TableSchema, return when operation is successful
|
||||
"""
|
||||
try:
|
||||
thrift_table_schema = self.client.DescribeTable(table_name)
|
||||
temp = self._client.DescribeTable(table_name)
|
||||
|
||||
# res = TableSchema(table_name=temp.table_name, dimension=temp.dimension,
|
||||
# index_type=temp.index_type, store_raw_vector=temp.store_raw_vector)
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e)), None
|
||||
# TODO Table Schema
|
||||
return Status(message='Success!'), thrift_table_schema
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), temp
|
||||
|
||||
def show_tables(self):
|
||||
"""
|
||||
@ -437,12 +245,36 @@ class MegaSearch(ConnectIntf):
|
||||
is successful
|
||||
"""
|
||||
try:
|
||||
tables = self.client.ShowTables()
|
||||
res = self._client.ShowTables()
|
||||
tables = []
|
||||
if res:
|
||||
tables, _ = res
|
||||
|
||||
except (TApplicationException, TException) as e:
|
||||
LOGGER.error('{}'.format(e))
|
||||
return Status(Status.INVALID, str(e)), None
|
||||
return Status(Status.PERMISSION_DENIED, str(e)), None
|
||||
return Status(message='Success!'), tables
|
||||
|
||||
def get_table_row_count(self, table_name):
    """
    Ask the server how many rows a table holds.

    :type  table_name: str
    :param table_name: name of the table to count.

    :returns:
        Status: indicates whether the operation succeeded
        res: int, table row count; None when the Thrift call raised

    """
    try:
        # The reply is unpacked as a pair; only its first element is used.
        reply = self._client.GetTableRowCount(table_name)
        row_count, _ = reply
    except (TApplicationException, TException) as err:
        LOGGER.error('{}'.format(err))
        failure = Status(Status.PERMISSION_DENIED, str(err))
        return failure, None

    return Status(message='Success'), row_count
|
||||
|
||||
def client_version(self):
|
||||
"""
|
||||
Provide client version
|
||||
@ -457,8 +289,10 @@ class MegaSearch(ConnectIntf):
|
||||
|
||||
:return: Server version
|
||||
"""
|
||||
# TODO How to get server version
|
||||
pass
|
||||
if not self.connected:
|
||||
raise NotConnectError('You have to connect first')
|
||||
|
||||
return self._client.Ping('version')
|
||||
|
||||
def server_status(self, cmd=None):
|
||||
"""
|
||||
@ -466,4 +300,7 @@ class MegaSearch(ConnectIntf):
|
||||
|
||||
:return: Server status
|
||||
"""
|
||||
return self.client.Ping(cmd)
|
||||
if not self.connected:
|
||||
raise NotConnectError('You have to connect first')
|
||||
|
||||
return self._client.Ping(cmd)
|
||||
|
@ -2,10 +2,6 @@ class ParamError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectParamMissingError(ParamError):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectError(ValueError):
|
||||
pass
|
||||
|
||||
|
@ -3,13 +3,15 @@ class Status(object):
|
||||
:attribute code : int (optional) default as ok
|
||||
:attribute message : str (optional) current status message
|
||||
"""
|
||||
OK = 0
|
||||
INVALID = 1
|
||||
UNKNOWN_ERROR = 2
|
||||
NOT_SUPPORTED = 3
|
||||
NOT_CONNECTED = 4
|
||||
SUCCESS = 0
|
||||
CONNECT_FAILED = 1
|
||||
PERMISSION_DENIED = 2
|
||||
TABLE_NOT_EXISTS = 3
|
||||
ILLEGAL_ARGUMENT = 4
|
||||
ILLEGAL_RANGE = 5
|
||||
ILLEGAL_DIMENSION = 6
|
||||
|
||||
def __init__(self, code=OK, message=None):
|
||||
def __init__(self, code=SUCCESS, message=None):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
@ -1,79 +1,92 @@
|
||||
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
|
||||
from client.Client import MegaSearch, Prepare, IndexType
|
||||
from client.Status import Status
|
||||
import time
|
||||
import random
|
||||
import struct
|
||||
from pprint import pprint
|
||||
|
||||
from megasearch.thrift import MegasearchService, ttypes
|
||||
|
||||
|
||||
def main():
|
||||
# Get client version
|
||||
mega = MegaSearch()
|
||||
print(mega.client_version())
|
||||
print('# Client version: {}'.format(mega.client_version()))
|
||||
|
||||
# Connect
|
||||
param = {'host': '192.168.1.129', 'port': '33001'}
|
||||
cnn_status = mega.connect(**param)
|
||||
print('Connect Status: {}'.format(cnn_status))
|
||||
print('# Connect Status: {}'.format(cnn_status))
|
||||
|
||||
# Check if connected
|
||||
is_connected = mega.connected
|
||||
print('Connect status: {}'.format(is_connected))
|
||||
print('# Is connected: {}'.format(is_connected))
|
||||
|
||||
# Create table with 1 vector column, 1 attribute column and 1 partition column
|
||||
# 1. prepare table_schema
|
||||
# Get server version
|
||||
print('# Server version: {}'.format(mega.server_version()))
|
||||
|
||||
# table_schema = Prepare.table_schema(
|
||||
# table_name='fake_table_name' + time.strftime('%H%M%S'),
|
||||
#
|
||||
# vector_columns=[Prepare.vector_column(
|
||||
# name='fake_vector_name' + time.strftime('%H%M%S'),
|
||||
# store_raw_vector=False,
|
||||
# dimension=256)],
|
||||
#
|
||||
# attribute_columns=[],
|
||||
#
|
||||
# partition_column_names=[]
|
||||
# )
|
||||
|
||||
# get server version
|
||||
print(mega.server_status('version'))
|
||||
print(mega.client.Ping('version'))
|
||||
# show tables and their description
|
||||
statu, tables = mega.show_tables()
|
||||
print(tables)
|
||||
|
||||
for table in tables:
|
||||
s,t = mega.describe_table(table)
|
||||
print('table: {}'.format(t))
|
||||
# Show tables and their description
|
||||
status, tables = mega.show_tables()
|
||||
print('# Show tables: {}'.format(tables))
|
||||
|
||||
# Create table
|
||||
# 1. create table schema
|
||||
table_schema_full = MegasearchService.TableSchema(
|
||||
table_name='fake' + time.strftime('%H%M%S'),
|
||||
# 01.Prepare data
|
||||
param = {
|
||||
'table_name': 'test'+ str(random.randint(0,999)),
|
||||
'dimension': 256,
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
|
||||
vector_column_array=[MegasearchService.VectorColumn(
|
||||
base=MegasearchService.Column(
|
||||
name='111',
|
||||
type=ttypes.TType.LIST
|
||||
),
|
||||
index_type="aaa",
|
||||
dimension=256,
|
||||
store_raw_vector=False,
|
||||
)],
|
||||
# 02.Create table
|
||||
res_status = mega.create_table(Prepare.table_schema(**param))
|
||||
print('# Create table status: {}'.format(res_status))
|
||||
|
||||
attribute_column_array=[],
|
||||
# Describe table
|
||||
table_name = 'test01'
|
||||
res_status, table = mega.describe_table(table_name)
|
||||
print('# Describe table status: {}'.format(res_status))
|
||||
print('# Describe table:{}'.format(table))
|
||||
|
||||
partition_column_name_array=[]
|
||||
)
|
||||
# Add vectors to table 'test01'
|
||||
# 01. Prepare data
|
||||
dim = 256
|
||||
# list of binary vectors
|
||||
vectors = [Prepare.row_record(struct.pack(str(dim)+'d',
|
||||
*[random.random()for _ in range(dim)]))
|
||||
for _ in range(20)]
|
||||
# 02. Add vectors
|
||||
status, ids = mega.add_vectors(table_name=table_name, records=vectors)
|
||||
print('# Add vector status: {}'.format(status))
|
||||
pprint(ids)
|
||||
|
||||
# 2. Create Table
|
||||
create_status = mega.client.CreateTable(param=table_schema_full)
|
||||
print('Create table status: {}'.format(create_status))
|
||||
# Search vectors
|
||||
q_records = [Prepare.row_record(struct.pack(str(dim) + 'd',
|
||||
*[random.random() for _ in range(dim)]))
|
||||
for _ in range(5)]
|
||||
param = {
|
||||
'table_name': 'test01',
|
||||
'query_records': q_records,
|
||||
'top_k': 10,
|
||||
# 'query_ranges': None # Optional
|
||||
}
|
||||
sta, results = mega.search_vectors(**param)
|
||||
print('# Search vectors status: {}'.format(sta))
|
||||
pprint(results)
|
||||
|
||||
# add_vector
|
||||
# Get table row count
|
||||
sta, result = mega.get_table_row_count(table_name)
|
||||
print('# Status: {}'.format(sta))
|
||||
print('# Count: {}'.format(result))
|
||||
|
||||
# Delete table 'test01'
|
||||
res_status = mega.delete_table(table_name)
|
||||
print('# Delete table status: {}'.format(res_status))
|
||||
|
||||
# Disconnect
|
||||
discnn_status = mega.disconnect()
|
||||
print('Disconnect Status{}'.format(discnn_status))
|
||||
print('# Disconnect Status: {}'.format(discnn_status))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
@ -3,17 +3,20 @@ import pytest
|
||||
import mock
|
||||
import faker
|
||||
import random
|
||||
import struct
|
||||
from faker.providers import BaseProvider
|
||||
|
||||
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
|
||||
from client.Client import MegaSearch, Prepare
|
||||
from client.Abstract import IndexType, TableSchema
|
||||
from client.Status import Status
|
||||
from client.Exceptions import (
|
||||
RepeatingConnectError,
|
||||
DisconnectNotConnectedClientError
|
||||
)
|
||||
from megasearch.thrift import ttypes, MegasearchService
|
||||
|
||||
from thrift.transport.TSocket import TSocket
|
||||
from megasearch.thrift import ttypes, MegasearchService
|
||||
from thrift.transport import TTransport
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
@ -35,63 +38,37 @@ fake = faker.Faker()
|
||||
fake.add_provider(FakerProvider)
|
||||
|
||||
|
||||
def vector_column_factory():
|
||||
return {
|
||||
'name': fake.name(),
|
||||
'dimension': fake.dim(),
|
||||
'store_raw_vector': True
|
||||
}
|
||||
|
||||
|
||||
def column_factory():
|
||||
return {
|
||||
'name': fake.table_name(),
|
||||
'type': ColumnType.INT32
|
||||
}
|
||||
|
||||
|
||||
def range_factory():
|
||||
return {
|
||||
param = {
|
||||
'start': str(random.randint(1, 10)),
|
||||
'end': str(random.randint(11, 20)),
|
||||
}
|
||||
return Prepare.range(**param)
|
||||
|
||||
|
||||
def ranges_factory():
    """Return a list of five ranges, each built by ``range_factory``."""
    generated = []
    for _ in range(5):
        generated.append(range_factory())
    return generated
|
||||
|
||||
|
||||
def table_schema_factory():
|
||||
vec_params = [vector_column_factory() for i in range(10)]
|
||||
column_params = [column_factory() for i in range(5)]
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
|
||||
'attribute_columns': [Prepare.column(**pa) for pa in column_params],
|
||||
'partition_column_names': [str(x) for x in range(2)]
|
||||
'dimension': random.randint(0, 999),
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
return Prepare.table_schema(**param)
|
||||
|
||||
|
||||
def create_table_partition_param_factory():
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'partition_name': fake.table_name(),
|
||||
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
|
||||
}
|
||||
return Prepare.create_table_partition_param(**param)
|
||||
def row_record_factory(dimension):
    """
    Build one row record holding a random vector of ``dimension`` values.

    Each component is ``random.random() + random.randint(0, 9)``; the values
    are packed as doubles (struct format ``'<dimension>d'``) and handed to
    ``Prepare.row_record`` as ``vector_data``.
    """
    components = []
    for _ in range(dimension):
        components.append(random.random() + random.randint(0, 9))

    layout = str(dimension) + "d"
    packed = struct.pack(layout, *components)

    return Prepare.row_record(vector_data=packed)
|
||||
|
||||
|
||||
def delete_table_partition_param_factory():
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'partition_names': [fake.name() for i in range(5)]
|
||||
}
|
||||
return Prepare.delete_table_partition_param(**param)
|
||||
|
||||
|
||||
def row_record_factory():
|
||||
param = {
|
||||
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
|
||||
'column_name_to_attribute': {fake.name(): fake.name()}
|
||||
}
|
||||
return Prepare.row_record(**param)
|
||||
def row_records_factory(dimension):
    """Return twenty row records, each built by ``row_record_factory(dimension)``."""
    records = []
    for _ in range(20):
        records.append(row_record_factory(dimension))
    return records
|
||||
|
||||
|
||||
class TestConnection:
|
||||
@ -103,9 +80,8 @@ class TestConnection:
|
||||
cnn = MegaSearch()
|
||||
|
||||
cnn.connect(**self.param)
|
||||
assert cnn.status == Status.OK
|
||||
assert cnn.status == Status.SUCCESS
|
||||
assert cnn.connected
|
||||
assert isinstance(cnn.client, MegasearchService.Client)
|
||||
|
||||
with pytest.raises(RepeatingConnectError):
|
||||
cnn.connect(**self.param)
|
||||
@ -114,12 +90,23 @@ class TestConnection:
|
||||
def test_false_connect(self):
|
||||
cnn = MegaSearch()
|
||||
|
||||
cnn.connect(self.param)
|
||||
assert cnn.status != Status.OK
|
||||
cnn.connect(**self.param)
|
||||
assert cnn.status != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(TTransport.TBufferedTransport, 'close')
@mock.patch.object(TSocket, 'open')
def test_disconnected(self, mock_open, mock_close):
    """
    Connect and then disconnect a client with transport calls mocked.

    Stacked ``mock.patch`` decorators inject their mocks bottom-up, so the
    first mock argument belongs to ``TSocket.open`` and the second to
    ``TBufferedTransport.close``. The argument names follow that order and
    no longer shadow the builtin ``open``.
    """
    # Both patched methods are replaced by mocks that simply return None.
    mock_open.return_value = None
    mock_close.return_value = None

    cnn = MegaSearch()
    cnn.connect(**self.param)

    assert cnn.disconnect() == Status.SUCCESS
|
||||
|
||||
def test_disconnected_error(self):
|
||||
cnn = MegaSearch()
|
||||
cnn.connect_status = Status(Status.INVALID)
|
||||
cnn.connect_status = Status(Status.PERMISSION_DENIED)
|
||||
with pytest.raises(DisconnectNotConnectedClientError):
|
||||
cnn.disconnect()
|
||||
|
||||
@ -142,26 +129,26 @@ class TestTable:
|
||||
|
||||
param = table_schema_factory()
|
||||
res = client.create_table(param)
|
||||
assert res == Status.OK
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_create_table(self, client):
|
||||
param = table_schema_factory()
|
||||
with pytest.raises(TTransportException):
|
||||
res = client.create_table(param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert res != Status.OK
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'DeleteTable')
|
||||
def test_delete_table(self, DeleteTable, client):
|
||||
DeleteTable.return_value = None
|
||||
table_name = 'fake_table_name'
|
||||
res = client.delete_table(table_name)
|
||||
assert res == Status.OK
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_delete_table(self, client):
|
||||
table_name = 'fake_table_name'
|
||||
res = client.delete_table(table_name)
|
||||
assert res != Status.OK
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
|
||||
class TestVector:
|
||||
@ -176,70 +163,46 @@ class TestVector:
|
||||
cnn.connect(**param)
|
||||
return cnn
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
|
||||
def test_create_table_partition(self, CreateTablePartition, client):
|
||||
CreateTablePartition.return_value = None
|
||||
|
||||
param = create_table_partition_param_factory()
|
||||
res = client.create_table_partition(param)
|
||||
assert res == Status.OK
|
||||
|
||||
def test_false_table_partition(self, client):
|
||||
param = create_table_partition_param_factory()
|
||||
res = client.create_table_partition(param)
|
||||
assert res != Status.OK
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
|
||||
def test_delete_table_partition(self, DeleteTablePartition, client):
|
||||
DeleteTablePartition.return_value = None
|
||||
|
||||
param = delete_table_partition_param_factory()
|
||||
res = client.delete_table_partition(param)
|
||||
assert res == Status.OK
|
||||
|
||||
def test_false_delete_table_partition(self, client):
|
||||
param = delete_table_partition_param_factory()
|
||||
res = client.delete_table_partition(param)
|
||||
assert res != Status.OK
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'AddVector')
|
||||
def test_add_vector(self, AddVector, client):
|
||||
AddVector.return_value = None
|
||||
|
||||
param ={
|
||||
'table_name': fake.table_name(),
|
||||
'records': [row_record_factory() for _ in range(1000)]
|
||||
'records': row_records_factory(256)
|
||||
}
|
||||
res, ids = client.add_vector(**param)
|
||||
assert res == Status.OK
|
||||
res, ids = client.add_vectors(**param)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_add_vector(self, client):
|
||||
param ={
|
||||
'table_name': fake.table_name(),
|
||||
'records': [row_record_factory() for _ in range(1000)]
|
||||
'records': row_records_factory(256)
|
||||
}
|
||||
res, ids = client.add_vector(**param)
|
||||
assert res != Status.OK
|
||||
res, ids = client.add_vectors(**param)
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'SearchVector')
|
||||
def test_search_vector(self, SearchVector, client):
|
||||
SearchVector.return_value = None
|
||||
SearchVector.return_value = None, None
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'query_records': [row_record_factory() for _ in range(1000)],
|
||||
'top_k': random.randint(0,10)
|
||||
'query_records': row_records_factory(256),
|
||||
'query_ranges': ranges_factory(),
|
||||
'top_k': random.randint(0, 10)
|
||||
}
|
||||
res, results = client.search_vector(**param)
|
||||
assert res == Status.OK
|
||||
res, results = client.search_vectors(**param)
|
||||
assert res == Status.SUCCESS
|
||||
|
||||
def test_false_vector(self, client):
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'query_records': [row_record_factory() for _ in range(1000)],
|
||||
'top_k': random.randint(0,10)
|
||||
'query_records': row_records_factory(256),
|
||||
'query_ranges': ranges_factory(),
|
||||
'top_k': random.randint(0, 10)
|
||||
}
|
||||
res, results = client.search_vector(**param)
|
||||
assert res != Status.OK
|
||||
res, results = client.search_vectors(**param)
|
||||
assert res != Status.SUCCESS
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'DescribeTable')
|
||||
def test_describe_table(self, DescribeTable, client):
|
||||
@ -247,27 +210,38 @@ class TestVector:
|
||||
|
||||
table_name = fake.table_name()
|
||||
res, table_schema = client.describe_table(table_name)
|
||||
assert res == Status.OK
|
||||
assert isinstance(table_schema, ttypes.TableSchema)
|
||||
assert res == Status.SUCCESS
|
||||
assert isinstance(table_schema, TableSchema)
|
||||
|
||||
def test_false_decribe_table(self, client):
|
||||
table_name = fake.table_name()
|
||||
res, table_schema = client.describe_table(table_name)
|
||||
assert res != Status.OK
|
||||
assert res != Status.SUCCESS
|
||||
assert not table_schema
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'ShowTables')
|
||||
def test_show_tables(self, ShowTables, client):
|
||||
ShowTables.return_value = [fake.table_name() for _ in range(10)]
|
||||
ShowTables.return_value = [fake.table_name() for _ in range(10)], None
|
||||
res, tables = client.show_tables()
|
||||
assert res == Status.OK
|
||||
assert res == Status.SUCCESS
|
||||
assert isinstance(tables, list)
|
||||
|
||||
def test_false_show_tables(self, client):
|
||||
res, tables = client.show_tables()
|
||||
assert res != Status.OK
|
||||
assert res != Status.SUCCESS
|
||||
assert not tables
|
||||
|
||||
@mock.patch.object(MegasearchService.Client, 'GetTableRowCount')
def test_get_table_row_count(self, GetTableRowCount, client):
    """
    Check that a successful row-count call reports success and the count.

    ``GetTableRowCount`` is mocked to return a ``(count, None)`` pair, the
    shape that ``get_table_row_count`` unpacks.
    """
    expected = 22
    GetTableRowCount.return_value = expected, None

    res, count = client.get_table_row_count('fake_table')

    assert res == Status.SUCCESS
    # The count was previously unpacked but never verified.
    assert count == expected
|
||||
|
||||
def test_false_get_table_row_count(self, client):
    """Without a mocked service, the call must not report success or a count."""
    outcome = client.get_table_row_count('fake_table')
    status, row_count = outcome
    assert status != Status.SUCCESS
    assert not row_count
|
||||
|
||||
def test_client_version(self, client):
|
||||
res = client.client_version()
|
||||
assert res == '0.0.1'
|
||||
@ -275,34 +249,13 @@ class TestVector:
|
||||
|
||||
class TestPrepare:
|
||||
|
||||
def test_column(self):
|
||||
param = {
|
||||
'name': 'test01',
|
||||
'type': ColumnType.DATE
|
||||
}
|
||||
res = Prepare.column(**param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert res.name == 'test01'
|
||||
assert res.type == ColumnType.DATE
|
||||
assert isinstance(res, ttypes.Column)
|
||||
|
||||
def test_vector_column(self):
|
||||
param = vector_column_factory()
|
||||
|
||||
res = Prepare.vector_column(**param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert isinstance(res, ttypes.VectorColumn)
|
||||
|
||||
def test_table_schema(self):
|
||||
|
||||
vec_params = [vector_column_factory() for i in range(10)]
|
||||
column_params = [column_factory() for i in range(5)]
|
||||
|
||||
param = {
|
||||
'table_name': 'test03',
|
||||
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
|
||||
'attribute_columns': [Prepare.column(**pa) for pa in column_params],
|
||||
'partition_column_names': [str(x) for x in range(2)]
|
||||
'table_name': fake.table_name(),
|
||||
'dimension': random.randint(0, 999),
|
||||
'index_type': IndexType.IDMAP,
|
||||
'store_raw_vector': False
|
||||
}
|
||||
res = Prepare.table_schema(**param)
|
||||
assert isinstance(res, ttypes.TableSchema)
|
||||
@ -319,39 +272,10 @@ class TestPrepare:
|
||||
assert res.start_value == '200'
|
||||
assert res.end_value == '1000'
|
||||
|
||||
def test_create_table_partition_param(self):
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'partition_name': fake.table_name(),
|
||||
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
|
||||
}
|
||||
res = Prepare.create_table_partition_param(**param)
|
||||
LOGGER.error('{}'.format(res))
|
||||
assert isinstance(res, ttypes.CreateTablePartitionParam)
|
||||
|
||||
def test_delete_table_partition_param(self):
|
||||
param = {
|
||||
'table_name': fake.table_name(),
|
||||
'partition_names': [fake.name() for i in range(5)]
|
||||
}
|
||||
res = Prepare.delete_table_partition_param(**param)
|
||||
assert isinstance(res, ttypes.DeleteTablePartitionParam)
|
||||
|
||||
def test_row_record(self):
|
||||
param={
|
||||
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
|
||||
'column_name_to_attribute': {fake.name(): fake.name()}
|
||||
}
|
||||
res = Prepare.row_record(**param)
|
||||
vec = [random.random() + random.randint(0, 9) for _ in range(256)]
|
||||
bin_vec = struct.pack(str(256) + "d", *vec)
|
||||
res = Prepare.row_record(bin_vec)
|
||||
assert isinstance(res, ttypes.RowRecord)
|
||||
|
||||
def test_query_record(self):
|
||||
param = {
|
||||
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
|
||||
'selected_columns': [fake.name() for _ in range(10)],
|
||||
'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]}
|
||||
}
|
||||
res = Prepare.query_record(**param)
|
||||
assert isinstance(res, ttypes.QueryRecord)
|
||||
|
||||
assert isinstance(bin_vec, bytes)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user