Implement SDK interface, part 2

Former-commit-id: e0f031025133d35456f685dfb2f4d3768ee99a56
This commit is contained in:
groot 2019-05-28 18:43:08 +08:00
parent 5fd733419b
commit 5c79f1883c
14 changed files with 251 additions and 76 deletions

View File

@ -1,7 +1,7 @@
server_config:
address: 0.0.0.0
port: 33001
transfer_protocol: json #optional: binary, compact, json, debug
transfer_protocol: binary #optional: binary, compact, json
server_mode: thread_pool #optional: simple, thread_pool
gpu_index: 0 #which gpu to be used

View File

@ -1,7 +1,7 @@
server_config:
address: 0.0.0.0
port: 33001
transfer_protocol: json #optional: binary, compact, json, debug
transfer_protocol: binary #optional: binary, compact, json
server_mode: thread_pool #optional: simple, thread_pool
gpu_index: 0 #which gpu to be used

View File

@ -13,13 +13,46 @@
using namespace megasearch;
namespace {
// Prints a horizontal separator line to stdout. The expansion already ends with
// a semicolon, so call sites write a bare `BLOCK_SPLITER` with no trailing `;`.
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
std::cout << "===========================================" << std::endl;
BLOCK_SPLITER
std::cout << "Table name: " << tb_schema.table_name << std::endl;
std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl;
std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl;
std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl;
std::cout << "===========================================" << std::endl;
BLOCK_SPLITER
}
// Prints how many record ids were returned, framed by separator lines.
// The per-id dump is compiled out; switch the `#if 0` to `#if 1` to enable it
// while debugging.
void PrintRecordIdArray(const std::vector<int64_t>& record_ids) {
    const auto id_count = record_ids.size();

    BLOCK_SPLITER
    std::cout << "Returned id array count: " << id_count << std::endl;
#if 0
    for (size_t pos = 0; pos < id_count; ++pos) {
        std::cout << std::to_string(record_ids[pos]) << std::endl;
    }
#endif
    BLOCK_SPLITER
}
// Dumps top-k search results to stdout, framed by separator lines.
//
// For every entry of `topk_query_result_array` a header "No. <n> vector top k
// search result:" is printed (n is 1-based and follows the order of the array),
// followed by one line per hit containing the id, the score and every
// name:value pair found in the hit's column_map.
void PrintSearchResult(const std::vector<TopKQueryResult>& topk_query_result_array) {
    BLOCK_SPLITER
    std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;

    int32_t index = 1;  // 1-based ordinal of the result group being printed
    for (const auto& result : topk_query_result_array) {
        std::cout << "No. " << std::to_string(index) << " vector top k search result:" << std::endl;
        // Advance the ordinal; without this every group was labelled "No. 1".
        ++index;

        for (const auto& item : result.query_result_arrays) {
            std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score);
            for (const auto& attri : item.column_map) {
                std::cout << "\t" << attri.first << ":" << attri.second;
            }
            std::cout << std::endl;
        }
    }
    BLOCK_SPLITER
}
std::string CurrentTime() {
@ -42,8 +75,29 @@ namespace {
static const std::string TABLE_NAME = GetTableName();
static const std::string VECTOR_COLUMN_NAME = "face_vector";
static const std::string AGE_COLUMN_NAME = "age";
static const std::string CITY_COLUMN_NAME = "city";
static const int64_t TABLE_DIMENSION = 512;
// Assembles the schema used by this client test: table TABLE_NAME with one
// vector column (VECTOR_COLUMN_NAME, TABLE_DIMENSION dimensions, raw vectors
// stored) and two attribute columns, AGE_COLUMN_NAME (int8) and
// CITY_COLUMN_NAME (int16).
TableSchema BuildTableSchema() {
    TableSchema schema;
    schema.table_name = TABLE_NAME;

    // The single vector column.
    VectorColumn vector_column;
    vector_column.store_raw_vector = true;
    vector_column.dimension = TABLE_DIMENSION;
    vector_column.name = VECTOR_COLUMN_NAME;
    schema.vector_column_array.emplace_back(vector_column);

    // Attribute columns, appended in the order age, city.
    Column age_column = {ColumnType::int8, AGE_COLUMN_NAME};
    Column city_column = {ColumnType::int16, CITY_COLUMN_NAME};
    schema.attribute_column_array.emplace_back(age_column);
    schema.attribute_column_array.emplace_back(city_column);

    return schema;
}
void BuildVectors(int64_t from, int64_t to,
std::vector<RowRecord>* vector_record_array,
std::vector<QueryRecord>* query_record_array) {
@ -58,6 +112,19 @@ namespace {
query_record_array->clear();
}
// Lookup table from a small integer key to a city name. The record-building
// loop indexes it with `k % CITY_MAP.size()` to give each generated record a
// value for the city attribute, so keys must stay contiguous from 0.
static const std::map<int64_t, std::string> CITY_MAP = {
    {0, "Beijing"},
    {1, "Shanghai"},
    {2, "Hangzhou"},
    {3, "Guangzhou"},
    {4, "Shenzhen"},
    {5, "Wuhan"},
    {6, "Chengdu"},
    {7, "Chongqing"},
    {8, "Tianjin"},
    {9, "Hongkong"},
};
for (int64_t k = from; k < to; k++) {
std::vector<float> f_p;
@ -69,12 +136,16 @@ namespace {
if(vector_record_array) {
RowRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.attribute_map[AGE_COLUMN_NAME] = std::to_string(k%100);
record.attribute_map[CITY_COLUMN_NAME] = CITY_MAP.at(k%CITY_MAP.size());
vector_record_array->emplace_back(record);
}
if(query_record_array) {
QueryRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.selected_column_array.push_back(AGE_COLUMN_NAME);
record.selected_column_array.push_back(CITY_COLUMN_NAME);
query_record_array->emplace_back(record);
}
}
@ -87,29 +158,30 @@ ClientTest::Test(const std::string& address, const std::string& port) {
ConnectParam param = { address, port };
conn->Connect(param);
{
std::cout << "ShowTables" << std::endl;
std::vector<std::string> tables;
Status stat = conn->ShowTables(tables);
std::cout << "Function call status: " << stat.ToString() << std::endl;
std::cout << "All tables: " << std::endl;
for(auto& table : tables) {
std::cout << "\t" << table << std::endl;
}
}
{//create table
TableSchema tb_schema;
VectorColumn col1;
col1.name = VECTOR_COLUMN_NAME;
col1.dimension = TABLE_DIMENSION;
col1.store_raw_vector = true;
tb_schema.vector_column_array.emplace_back(col1);
Column col2;
col2.name = "age";
tb_schema.attribute_column_array.emplace_back(col2);
tb_schema.table_name = TABLE_NAME;
TableSchema tb_schema = BuildTableSchema();
PrintTableSchema(tb_schema);
std::cout << "CreateTable" << std::endl;
Status stat = conn->CreateTable(tb_schema);
std::cout << "Create table result: " << stat.ToString() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
}
{//describe table
TableSchema tb_schema;
std::cout << "DescribeTable" << std::endl;
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
std::cout << "Describe table result: " << stat.ToString() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintTableSchema(tb_schema);
}
@ -117,10 +189,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::vector<RowRecord> record_array;
BuildVectors(0, 10000, &record_array, nullptr);
std::vector<int64_t> record_ids;
std::cout << "Begin add vectors" << std::endl;
std::cout << "AddVector" << std::endl;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "Add vector result: " << stat.ToString() << std::endl;
std::cout << "Returned vector ids: " << record_ids.size() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintRecordIdArray(record_ids);
}
{//search vectors
@ -129,10 +201,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
BuildVectors(500, 510, nullptr, &record_array);
std::vector<TopKQueryResult> topk_query_result_array;
std::cout << "Begin search vectors" << std::endl;
std::cout << "SearchVector" << std::endl;
Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10);
std::cout << "Search vector result: " << stat.ToString() << std::endl;
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintSearchResult(topk_query_result_array);
}
// {//delete table

View File

@ -21,7 +21,7 @@ ClientProxy::Connect(const ConnectParam &param) {
Disconnect();
int32_t port = atoi(param.port.c_str());
return ClientPtr()->Connect(param.ip_address, port, "json");
return ClientPtr()->Connect(param.ip_address, port, THRIFT_PROTOCOL_BINARY);
}
Status
@ -58,7 +58,7 @@ ClientProxy::Disconnect() {
std::string
ClientProxy::ClientVersion() const {
return std::string("Current Version");
return std::string("v1.0");
}
Status

View File

@ -50,14 +50,12 @@ ThriftClient::Connect(const std::string& address, int32_t port, const std::strin
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
stdcxx::shared_ptr<TProtocol> protocol_ptr;
if(protocol == "binary") {
if(protocol == THRIFT_PROTOCOL_BINARY) {
protocol_ptr.reset(new TBinaryProtocol(transport_ptr));
} else if(protocol == "json") {
} else if(protocol == THRIFT_PROTOCOL_JSON) {
protocol_ptr.reset(new TJSONProtocol(transport_ptr));
} else if(protocol == "compact") {
} else if(protocol == THRIFT_PROTOCOL_COMPACT) {
protocol_ptr.reset(new TCompactProtocol(transport_ptr));
} else if(protocol == "debug") {
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
} else {
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return Status(StatusCode::Invalid, "unsupported protocol");

View File

@ -14,6 +14,10 @@ namespace megasearch {
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
static const std::string THRIFT_PROTOCOL_JSON = "json";
static const std::string THRIFT_PROTOCOL_BINARY = "binary";
static const std::string THRIFT_PROTOCOL_COMPACT = "compact";
class ThriftClient {
public:
ThriftClient();

View File

@ -67,14 +67,15 @@ MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std:
void
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
// Your implementation goes here
printf("ShowTables\n");
BaseTaskPtr task_ptr = ShowTablesTask::Create(_return);
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
// Your implementation goes here
printf("Ping\n");
if(cmd == "version") {
_return = "v1.2.0";
}
}
}

View File

@ -54,7 +54,6 @@ MegasearchServer::StartService() {
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
std::string protocol = "json";
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
if (protocol == "binary") {
protocol_factory.reset(new TBinaryProtocolFactory());
@ -62,8 +61,6 @@ MegasearchServer::StartService() {
protocol_factory.reset(new TJSONProtocolFactory());
} else if (protocol == "compact") {
protocol_factory.reset(new TCompactProtocolFactory());
} else if (protocol == "debug") {
protocol_factory.reset(new TDebugProtocolFactory());
} else {
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
return;

View File

@ -21,6 +21,7 @@ namespace server {
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
static const std::string PING_TASK_GROUP = "ping";
static const std::string VECTOR_UID = "uid";
static const uint64_t USE_MT = 5000;
@ -48,6 +49,10 @@ namespace {
}
}
// Deletes the engine instance held in db_ when the wrapper is destroyed.
// NOTE(review): assumes db_ was obtained with `new` (or is null) and is owned
// only by this wrapper — confirm against the constructor, which is not shown here.
~DBWrapper() {
delete db_;
}
// Returns the raw engine pointer held in db_. The wrapper's destructor deletes
// that pointer, so callers must not delete it themselves.
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
@ -78,17 +83,17 @@ BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
ServerError CreateTableTask::OnExecute() {
TimeRecorder rc("CreateTableTask");
try {
if(schema_.vector_column_array.empty()) {
return SERVER_INVALID_ARGUMENT;
}
IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);
engine::meta::TableSchema table_schema;
table_schema.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
table_schema.table_id = schema_.table_name;
engine::Status stat = DB()->CreateTable(table_schema);
engine::meta::TableSchema table_info;
table_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
table_info.table_id = schema_.table_name;
engine::Status stat = DB()->CreateTable(table_info);
if(!stat.ok()) {//could exist
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
@ -109,7 +114,7 @@ ServerError CreateTableTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
: BaseTask(DDL_DML_TASK_GROUP),
: BaseTask(PING_TASK_GROUP),
table_name_(table_name),
schema_(schema) {
schema_.table_name = table_name_;
@ -123,9 +128,9 @@ ServerError DescribeTableTask::OnExecute() {
TimeRecorder rc("DescribeTableTask");
try {
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
@ -154,8 +159,8 @@ DeleteTableTask::DeleteTableTask(const std::string& table_name)
}
BaseTaskPtr DeleteTableTask::Create(const std::string& table_id) {
return std::shared_ptr<BaseTask>(new DeleteTableTask(table_id));
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
return std::shared_ptr<BaseTask>(new DeleteTableTask(group_id));
}
ServerError DeleteTableTask::OnExecute() {
@ -168,6 +173,22 @@ ServerError DeleteTableTask::OnExecute() {
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Registers the task under PING_TASK_GROUP and binds tables_ to the caller's
// output vector. Nothing is copied here: OnExecute() writes into `tables`
// itself, so that vector must still be alive when the task runs.
ShowTablesTask::ShowTablesTask(std::vector<std::string>& tables)
: BaseTask(PING_TASK_GROUP),
tables_(tables) {
}
// Factory: wraps a new ShowTablesTask that will write its result into `tables`.
// The object is built with a plain `new` inside the class's own scope rather
// than through a standard-library helper.
BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
    std::shared_ptr<BaseTask> task(new ShowTablesTask(tables));
    return task;
}
// Asks the id mapper for the names of all groups and stores them in tables_.
// Returns whatever AllGroups() reports, so a failure in the mapper is surfaced
// to the caller instead of being masked by an unconditional SERVER_SUCCESS.
ServerError ShowTablesTask::OnExecute() {
    return IVecIdMapper::GetInstance()->AllGroups(tables_);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddVectorTask::AddVectorTask(const std::string& table_name,
@ -195,9 +216,9 @@ ServerError AddVectorTask::OnExecute() {
return SERVER_SUCCESS;
}
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
@ -208,7 +229,7 @@ ServerError AddVectorTask::OnExecute() {
rc.Record("get group info");
uint64_t vec_count = (uint64_t)record_array_.size();
uint64_t group_dim = table_schema.dimension;
uint64_t group_dim = table_info.dimension;
std::vector<float> vec_f;
vec_f.resize(vec_count*group_dim);//allocate enough memory
for(uint64_t i = 0; i < vec_count; i++) {
@ -228,6 +249,7 @@ ServerError AddVectorTask::OnExecute() {
return error_code_;
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
@ -245,12 +267,27 @@ ServerError AddVectorTask::OnExecute() {
return error_code_;
}
if(record_ids_.size() < vec_count) {
if(record_ids_.size() != vec_count) {
SERVER_LOG_ERROR << "Vector ID not returned";
return SERVER_UNEXPECTED_ERROR;
}
rc.Record("done");
//persist attributes
for(uint64_t i = 0; i < vec_count; i++) {
const auto &record = record_array_[i];
//any attributes?
if(record.attribute_map.empty()) {
continue;
}
std::string nid = std::to_string(record_ids_[i]);
std::string attrib_str;
AttributeSerializer::Encode(record.attribute_map, attrib_str);
IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_);
}
rc.Record("persist vector attributes");
} catch (std::exception& ex) {
error_code_ = SERVER_UNEXPECTED_ERROR;
@ -293,9 +330,9 @@ ServerError SearchVectorTask::OnExecute() {
return error_code_;
}
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
@ -305,7 +342,7 @@ ServerError SearchVectorTask::OnExecute() {
std::vector<float> vec_f;
uint64_t record_count = (uint64_t)record_array_.size();
vec_f.resize(record_count*table_schema.dimension);
vec_f.resize(record_count*table_info.dimension);
for(uint64_t i = 0; i < record_array_.size(); i++) {
const auto& record = record_array_[i];
@ -317,14 +354,15 @@ ServerError SearchVectorTask::OnExecute() {
}
uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
if (vec_dim != table_schema.dimension) {
if (vec_dim != table_info.dimension) {
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
<< " vs. group dimension:" << table_schema.dimension;
<< " vs. group dimension:" << table_info.dimension;
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
error_msg_ = "Engine failed: " + stat.ToString();
return error_code_;
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
@ -336,25 +374,50 @@ ServerError SearchVectorTask::OnExecute() {
std::vector<DB_DATE> dates;
engine::QueryResults results;
stat = DB()->Query(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
rc.Record("search vectors from engine");
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
return SERVER_UNEXPECTED_ERROR;
} else {
rc.Record("do searching");
for(engine::QueryResult& result : results){
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
result_array_.emplace_back(thrift_topk_result);
}
rc.Record("construct result");
}
rc.Record("done");
if(results.size() != record_count) {
SERVER_LOG_ERROR << "Search result not returned";
return SERVER_UNEXPECTED_ERROR;
}
//construct result array
for(uint64_t i = 0; i < record_count; i++) {
auto& result = results[i];
const auto& record = record_array_[i];
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
//need get attributes?
if(record.selected_column_array.empty()) {
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
continue;
}
std::string nid = std::to_string(id);
std::string attrib_str;
IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_);
AttribMap attrib_map;
AttributeSerializer::Decode(attrib_str, attrib_map);
for(auto& attri : record.selected_column_array) {
thrift_result.column_map[attri] = attrib_map[attri];
}
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
result_array_.emplace_back(thrift_topk_result);
}
rc.Record("construct result");
} catch (std::exception& ex) {
error_code_ = SERVER_UNEXPECTED_ERROR;

View File

@ -65,6 +65,20 @@ private:
std::string table_name_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that collects the names of all groups known to the id mapper into a
// caller-supplied vector. Instances are obtained through Create().
class ShowTablesTask : public BaseTask {
public:
    // Builds a task that will write the collected names into `tables`.
    // `tables` is kept by reference and must outlive the task's execution.
    static BaseTaskPtr Create(std::vector<std::string>& tables);

protected:
    // Protected so that construction goes through Create(); `explicit` rules
    // out accidental implicit conversion from a vector of strings.
    explicit ShowTablesTask(std::vector<std::string>& tables);

    // Fills tables_ and returns the resulting status code.
    ServerError OnExecute() override;

private:
    // Caller-owned output vector (not a copy).
    std::vector<std::string>& tables_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class AddVectorTask : public BaseTask {
public:

View File

@ -108,6 +108,19 @@ bool RocksIdMapper::IsGroupExist(const std::string& group) const {
return IsGroupExistInternal(group);
}
// Replaces the content of `groups` with the key of every entry in
// column_handles_ except ROCKSDB_DEFAULT_GROUP. db_mutex_ is held while the
// container is walked. Always returns SERVER_SUCCESS.
ServerError RocksIdMapper::AllGroups(std::vector<std::string>& groups) const {
    groups.clear();

    std::lock_guard<std::mutex> guard(db_mutex_);
    for (auto it = column_handles_.begin(); it != column_handles_.end(); ++it) {
        // Every key except the default group is reported.
        if (!(it->first == ROCKSDB_DEFAULT_GROUP)) {
            groups.push_back(it->first);
        }
    }

    return SERVER_SUCCESS;
}
ServerError RocksIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);

View File

@ -26,6 +26,7 @@ public:
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;

View File

@ -50,6 +50,16 @@ bool SimpleIdMapper::IsGroupExist(const std::string& group) const {
return id_groups_.count(group) > 0;
}
// Replaces the content of `groups` with the key of every entry in id_groups_.
// No lock is taken here. Always returns SERVER_SUCCESS.
ServerError SimpleIdMapper::AllGroups(std::vector<std::string>& groups) const {
    groups.clear();

    for (auto it = id_groups_.begin(); it != id_groups_.end(); ++it) {
        groups.push_back(it->first);
    }

    return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
ID_MAPPING& mapping = id_groups_[group];

View File

@ -27,6 +27,7 @@ public:
virtual ServerError AddGroup(const std::string& group) = 0;
virtual bool IsGroupExist(const std::string& group) const = 0;
virtual ServerError AllGroups(std::vector<std::string>& groups) const = 0;
virtual ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") = 0;
virtual ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") = 0;
@ -46,6 +47,7 @@ public:
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;