mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 05:18:52 +08:00
implement sdk interface part2
Former-commit-id: e0f031025133d35456f685dfb2f4d3768ee99a56
This commit is contained in:
parent
5fd733419b
commit
5c79f1883c
@ -1,7 +1,7 @@
|
||||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 33001
|
||||
transfer_protocol: json #optional: binary, compact, json, debug
|
||||
transfer_protocol: binary #optional: binary, compact, json
|
||||
server_mode: thread_pool #optional: simple, thread_pool
|
||||
gpu_index: 0 #which gpu to be used
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 33001
|
||||
transfer_protocol: json #optional: binary, compact, json, debug
|
||||
transfer_protocol: binary #optional: binary, compact, json
|
||||
server_mode: thread_pool #optional: simple, thread_pool
|
||||
gpu_index: 0 #which gpu to be used
|
||||
|
||||
|
@ -13,13 +13,46 @@
|
||||
using namespace megasearch;
|
||||
|
||||
namespace {
|
||||
|
||||
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
|
||||
|
||||
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
|
||||
std::cout << "===========================================" << std::endl;
|
||||
BLOCK_SPLITER
|
||||
std::cout << "Table name: " << tb_schema.table_name << std::endl;
|
||||
std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl;
|
||||
std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl;
|
||||
std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl;
|
||||
std::cout << "===========================================" << std::endl;
|
||||
BLOCK_SPLITER
|
||||
}
|
||||
|
||||
void PrintRecordIdArray(const std::vector<int64_t>& record_ids) {
|
||||
BLOCK_SPLITER
|
||||
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
|
||||
#if 0
|
||||
for(auto id : record_ids) {
|
||||
std::cout << std::to_string(id) << std::endl;
|
||||
}
|
||||
#endif
|
||||
BLOCK_SPLITER
|
||||
}
|
||||
|
||||
void PrintSearchResult(const std::vector<TopKQueryResult>& topk_query_result_array) {
|
||||
BLOCK_SPLITER
|
||||
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
|
||||
|
||||
int32_t index = 1;
|
||||
for(auto& result : topk_query_result_array) {
|
||||
std::cout << "No. " << std::to_string(index) << " vector top k search result:" << std::endl;
|
||||
for(auto& item : result.query_result_arrays) {
|
||||
std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score);
|
||||
for(auto& attri : item.column_map) {
|
||||
std::cout << "\t" << attri.first << ":" << attri.second;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
BLOCK_SPLITER
|
||||
}
|
||||
|
||||
std::string CurrentTime() {
|
||||
@ -42,8 +75,29 @@ namespace {
|
||||
|
||||
static const std::string TABLE_NAME = GetTableName();
|
||||
static const std::string VECTOR_COLUMN_NAME = "face_vector";
|
||||
static const std::string AGE_COLUMN_NAME = "age";
|
||||
static const std::string CITY_COLUMN_NAME = "city";
|
||||
static const int64_t TABLE_DIMENSION = 512;
|
||||
|
||||
TableSchema BuildTableSchema() {
|
||||
TableSchema tb_schema;
|
||||
VectorColumn col1;
|
||||
col1.name = VECTOR_COLUMN_NAME;
|
||||
col1.dimension = TABLE_DIMENSION;
|
||||
col1.store_raw_vector = true;
|
||||
tb_schema.vector_column_array.emplace_back(col1);
|
||||
|
||||
Column col2 = {ColumnType::int8, AGE_COLUMN_NAME};
|
||||
tb_schema.attribute_column_array.emplace_back(col2);
|
||||
|
||||
Column col3 = {ColumnType::int16, CITY_COLUMN_NAME};
|
||||
tb_schema.attribute_column_array.emplace_back(col3);
|
||||
|
||||
tb_schema.table_name = TABLE_NAME;
|
||||
|
||||
return tb_schema;
|
||||
}
|
||||
|
||||
void BuildVectors(int64_t from, int64_t to,
|
||||
std::vector<RowRecord>* vector_record_array,
|
||||
std::vector<QueryRecord>* query_record_array) {
|
||||
@ -58,6 +112,19 @@ namespace {
|
||||
query_record_array->clear();
|
||||
}
|
||||
|
||||
static const std::map<int64_t , std::string> CITY_MAP = {
|
||||
{0, "Beijing"},
|
||||
{1, "Shanhai"},
|
||||
{2, "Hangzhou"},
|
||||
{3, "Guangzhou"},
|
||||
{4, "Shenzheng"},
|
||||
{5, "Wuhan"},
|
||||
{6, "Chengdu"},
|
||||
{7, "Chongqin"},
|
||||
{8, "Tianjing"},
|
||||
{9, "Hongkong"},
|
||||
};
|
||||
|
||||
for (int64_t k = from; k < to; k++) {
|
||||
|
||||
std::vector<float> f_p;
|
||||
@ -69,12 +136,16 @@ namespace {
|
||||
if(vector_record_array) {
|
||||
RowRecord record;
|
||||
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
|
||||
record.attribute_map[AGE_COLUMN_NAME] = std::to_string(k%100);
|
||||
record.attribute_map[CITY_COLUMN_NAME] = CITY_MAP.at(k%CITY_MAP.size());
|
||||
vector_record_array->emplace_back(record);
|
||||
}
|
||||
|
||||
if(query_record_array) {
|
||||
QueryRecord record;
|
||||
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
|
||||
record.selected_column_array.push_back(AGE_COLUMN_NAME);
|
||||
record.selected_column_array.push_back(CITY_COLUMN_NAME);
|
||||
query_record_array->emplace_back(record);
|
||||
}
|
||||
}
|
||||
@ -87,29 +158,30 @@ ClientTest::Test(const std::string& address, const std::string& port) {
|
||||
ConnectParam param = { address, port };
|
||||
conn->Connect(param);
|
||||
|
||||
{
|
||||
std::cout << "ShowTables" << std::endl;
|
||||
std::vector<std::string> tables;
|
||||
Status stat = conn->ShowTables(tables);
|
||||
std::cout << "Function call status: " << stat.ToString() << std::endl;
|
||||
std::cout << "All tables: " << std::endl;
|
||||
for(auto& table : tables) {
|
||||
std::cout << "\t" << table << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
{//create table
|
||||
TableSchema tb_schema;
|
||||
VectorColumn col1;
|
||||
col1.name = VECTOR_COLUMN_NAME;
|
||||
col1.dimension = TABLE_DIMENSION;
|
||||
col1.store_raw_vector = true;
|
||||
tb_schema.vector_column_array.emplace_back(col1);
|
||||
|
||||
Column col2;
|
||||
col2.name = "age";
|
||||
tb_schema.attribute_column_array.emplace_back(col2);
|
||||
|
||||
tb_schema.table_name = TABLE_NAME;
|
||||
|
||||
TableSchema tb_schema = BuildTableSchema();
|
||||
PrintTableSchema(tb_schema);
|
||||
std::cout << "CreateTable" << std::endl;
|
||||
Status stat = conn->CreateTable(tb_schema);
|
||||
std::cout << "Create table result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Function call status: " << stat.ToString() << std::endl;
|
||||
}
|
||||
|
||||
{//describe table
|
||||
TableSchema tb_schema;
|
||||
std::cout << "DescribeTable" << std::endl;
|
||||
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
|
||||
std::cout << "Describe table result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Function call status: " << stat.ToString() << std::endl;
|
||||
PrintTableSchema(tb_schema);
|
||||
}
|
||||
|
||||
@ -117,10 +189,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
|
||||
std::vector<RowRecord> record_array;
|
||||
BuildVectors(0, 10000, &record_array, nullptr);
|
||||
std::vector<int64_t> record_ids;
|
||||
std::cout << "Begin add vectors" << std::endl;
|
||||
std::cout << "AddVector" << std::endl;
|
||||
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
|
||||
std::cout << "Add vector result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Returned vector ids: " << record_ids.size() << std::endl;
|
||||
std::cout << "Function call status: " << stat.ToString() << std::endl;
|
||||
PrintRecordIdArray(record_ids);
|
||||
}
|
||||
|
||||
{//search vectors
|
||||
@ -129,10 +201,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
|
||||
BuildVectors(500, 510, nullptr, &record_array);
|
||||
|
||||
std::vector<TopKQueryResult> topk_query_result_array;
|
||||
std::cout << "Begin search vectors" << std::endl;
|
||||
std::cout << "SearchVector" << std::endl;
|
||||
Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10);
|
||||
std::cout << "Search vector result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
|
||||
std::cout << "Function call status: " << stat.ToString() << std::endl;
|
||||
PrintSearchResult(topk_query_result_array);
|
||||
}
|
||||
|
||||
// {//delete table
|
||||
|
@ -21,7 +21,7 @@ ClientProxy::Connect(const ConnectParam ¶m) {
|
||||
Disconnect();
|
||||
|
||||
int32_t port = atoi(param.port.c_str());
|
||||
return ClientPtr()->Connect(param.ip_address, port, "json");
|
||||
return ClientPtr()->Connect(param.ip_address, port, THRIFT_PROTOCOL_BINARY);
|
||||
}
|
||||
|
||||
Status
|
||||
@ -58,7 +58,7 @@ ClientProxy::Disconnect() {
|
||||
|
||||
std::string
|
||||
ClientProxy::ClientVersion() const {
|
||||
return std::string("Current Version");
|
||||
return std::string("v1.0");
|
||||
}
|
||||
|
||||
Status
|
||||
|
@ -50,14 +50,12 @@ ThriftClient::Connect(const std::string& address, int32_t port, const std::strin
|
||||
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
|
||||
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
|
||||
stdcxx::shared_ptr<TProtocol> protocol_ptr;
|
||||
if(protocol == "binary") {
|
||||
if(protocol == THRIFT_PROTOCOL_BINARY) {
|
||||
protocol_ptr.reset(new TBinaryProtocol(transport_ptr));
|
||||
} else if(protocol == "json") {
|
||||
} else if(protocol == THRIFT_PROTOCOL_JSON) {
|
||||
protocol_ptr.reset(new TJSONProtocol(transport_ptr));
|
||||
} else if(protocol == "compact") {
|
||||
} else if(protocol == THRIFT_PROTOCOL_COMPACT) {
|
||||
protocol_ptr.reset(new TCompactProtocol(transport_ptr));
|
||||
} else if(protocol == "debug") {
|
||||
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
|
||||
} else {
|
||||
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
|
||||
return Status(StatusCode::Invalid, "unsupported protocol");
|
||||
|
@ -14,6 +14,10 @@ namespace megasearch {
|
||||
|
||||
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
|
||||
|
||||
static const std::string THRIFT_PROTOCOL_JSON = "json";
|
||||
static const std::string THRIFT_PROTOCOL_BINARY = "binary";
|
||||
static const std::string THRIFT_PROTOCOL_COMPACT = "compact";
|
||||
|
||||
class ThriftClient {
|
||||
public:
|
||||
ThriftClient();
|
||||
|
@ -67,14 +67,15 @@ MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std:
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
|
||||
// Your implementation goes here
|
||||
printf("ShowTables\n");
|
||||
BaseTaskPtr task_ptr = ShowTablesTask::Create(_return);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
|
||||
// Your implementation goes here
|
||||
printf("Ping\n");
|
||||
if(cmd == "version") {
|
||||
_return = "v1.2.0";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -54,7 +54,6 @@ MegasearchServer::StartService() {
|
||||
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
|
||||
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
|
||||
|
||||
std::string protocol = "json";
|
||||
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
|
||||
if (protocol == "binary") {
|
||||
protocol_factory.reset(new TBinaryProtocolFactory());
|
||||
@ -62,8 +61,6 @@ MegasearchServer::StartService() {
|
||||
protocol_factory.reset(new TJSONProtocolFactory());
|
||||
} else if (protocol == "compact") {
|
||||
protocol_factory.reset(new TCompactProtocolFactory());
|
||||
} else if (protocol == "debug") {
|
||||
protocol_factory.reset(new TDebugProtocolFactory());
|
||||
} else {
|
||||
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
|
||||
return;
|
||||
|
@ -21,6 +21,7 @@ namespace server {
|
||||
|
||||
static const std::string DQL_TASK_GROUP = "dql";
|
||||
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
|
||||
static const std::string PING_TASK_GROUP = "ping";
|
||||
|
||||
static const std::string VECTOR_UID = "uid";
|
||||
static const uint64_t USE_MT = 5000;
|
||||
@ -48,6 +49,10 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
~DBWrapper() {
|
||||
delete db_;
|
||||
}
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() { return db_; }
|
||||
|
||||
private:
|
||||
@ -78,17 +83,17 @@ BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
|
||||
|
||||
ServerError CreateTableTask::OnExecute() {
|
||||
TimeRecorder rc("CreateTableTask");
|
||||
|
||||
|
||||
try {
|
||||
if(schema_.vector_column_array.empty()) {
|
||||
return SERVER_INVALID_ARGUMENT;
|
||||
}
|
||||
|
||||
IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);
|
||||
engine::meta::TableSchema table_schema;
|
||||
table_schema.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
|
||||
table_schema.table_id = schema_.table_name;
|
||||
engine::Status stat = DB()->CreateTable(table_schema);
|
||||
engine::meta::TableSchema table_info;
|
||||
table_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
|
||||
table_info.table_id = schema_.table_name;
|
||||
engine::Status stat = DB()->CreateTable(table_info);
|
||||
if(!stat.ok()) {//could exist
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
@ -109,7 +114,7 @@ ServerError CreateTableTask::OnExecute() {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
: BaseTask(PING_TASK_GROUP),
|
||||
table_name_(table_name),
|
||||
schema_(schema) {
|
||||
schema_.table_name = table_name_;
|
||||
@ -123,9 +128,9 @@ ServerError DescribeTableTask::OnExecute() {
|
||||
TimeRecorder rc("DescribeTableTask");
|
||||
|
||||
try {
|
||||
engine::meta::TableSchema table_schema;
|
||||
table_schema.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_schema);
|
||||
engine::meta::TableSchema table_info;
|
||||
table_info.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
@ -154,8 +159,8 @@ DeleteTableTask::DeleteTableTask(const std::string& table_name)
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr DeleteTableTask::Create(const std::string& table_id) {
|
||||
return std::shared_ptr<BaseTask>(new DeleteTableTask(table_id));
|
||||
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
|
||||
return std::shared_ptr<BaseTask>(new DeleteTableTask(group_id));
|
||||
}
|
||||
|
||||
ServerError DeleteTableTask::OnExecute() {
|
||||
@ -168,6 +173,22 @@ ServerError DeleteTableTask::OnExecute() {
|
||||
return SERVER_NOT_IMPLEMENT;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
ShowTablesTask::ShowTablesTask(std::vector<std::string>& tables)
|
||||
: BaseTask(PING_TASK_GROUP),
|
||||
tables_(tables) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
|
||||
return std::shared_ptr<BaseTask>(new ShowTablesTask(tables));
|
||||
}
|
||||
|
||||
ServerError ShowTablesTask::OnExecute() {
|
||||
IVecIdMapper::GetInstance()->AllGroups(tables_);
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddVectorTask::AddVectorTask(const std::string& table_name,
|
||||
@ -195,9 +216,9 @@ ServerError AddVectorTask::OnExecute() {
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
engine::meta::TableSchema table_schema;
|
||||
table_schema.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_schema);
|
||||
engine::meta::TableSchema table_info;
|
||||
table_info.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
@ -208,7 +229,7 @@ ServerError AddVectorTask::OnExecute() {
|
||||
rc.Record("get group info");
|
||||
|
||||
uint64_t vec_count = (uint64_t)record_array_.size();
|
||||
uint64_t group_dim = table_schema.dimension;
|
||||
uint64_t group_dim = table_info.dimension;
|
||||
std::vector<float> vec_f;
|
||||
vec_f.resize(vec_count*group_dim);//allocate enough memory
|
||||
for(uint64_t i = 0; i < vec_count; i++) {
|
||||
@ -228,6 +249,7 @@ ServerError AddVectorTask::OnExecute() {
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
//convert double array to float array(thrift has no float type)
|
||||
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[i*vec_dim + d] = (float)(d_p[d]);
|
||||
@ -245,12 +267,27 @@ ServerError AddVectorTask::OnExecute() {
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
if(record_ids_.size() < vec_count) {
|
||||
if(record_ids_.size() != vec_count) {
|
||||
SERVER_LOG_ERROR << "Vector ID not returned";
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
//persist attributes
|
||||
for(uint64_t i = 0; i < vec_count; i++) {
|
||||
const auto &record = record_array_[i];
|
||||
|
||||
//any attributes?
|
||||
if(record.attribute_map.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string nid = std::to_string(record_ids_[i]);
|
||||
std::string attrib_str;
|
||||
AttributeSerializer::Encode(record.attribute_map, attrib_str);
|
||||
IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_);
|
||||
}
|
||||
|
||||
rc.Record("persist vector attributes");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
@ -293,9 +330,9 @@ ServerError SearchVectorTask::OnExecute() {
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
engine::meta::TableSchema table_schema;
|
||||
table_schema.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_schema);
|
||||
engine::meta::TableSchema table_info;
|
||||
table_info.table_id = table_name_;
|
||||
engine::Status stat = DB()->DescribeTable(table_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
@ -305,7 +342,7 @@ ServerError SearchVectorTask::OnExecute() {
|
||||
|
||||
std::vector<float> vec_f;
|
||||
uint64_t record_count = (uint64_t)record_array_.size();
|
||||
vec_f.resize(record_count*table_schema.dimension);
|
||||
vec_f.resize(record_count*table_info.dimension);
|
||||
|
||||
for(uint64_t i = 0; i < record_array_.size(); i++) {
|
||||
const auto& record = record_array_[i];
|
||||
@ -317,14 +354,15 @@ ServerError SearchVectorTask::OnExecute() {
|
||||
}
|
||||
|
||||
uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
|
||||
if (vec_dim != table_schema.dimension) {
|
||||
if (vec_dim != table_info.dimension) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << table_schema.dimension;
|
||||
<< " vs. group dimension:" << table_info.dimension;
|
||||
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
//convert double array to float array(thrift has no float type)
|
||||
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[i*vec_dim + d] = (float)(d_p[d]);
|
||||
@ -336,25 +374,50 @@ ServerError SearchVectorTask::OnExecute() {
|
||||
std::vector<DB_DATE> dates;
|
||||
engine::QueryResults results;
|
||||
stat = DB()->Query(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
|
||||
rc.Record("search vectors from engine");
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
} else {
|
||||
rc.Record("do searching");
|
||||
for(engine::QueryResult& result : results){
|
||||
thrift::TopKQueryResult thrift_topk_result;
|
||||
for(auto id : result) {
|
||||
thrift::QueryResult thrift_result;
|
||||
thrift_result.__set_id(id);
|
||||
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
|
||||
}
|
||||
|
||||
result_array_.emplace_back(thrift_topk_result);
|
||||
}
|
||||
rc.Record("construct result");
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
if(results.size() != record_count) {
|
||||
SERVER_LOG_ERROR << "Search result not returned";
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
//construct result array
|
||||
for(uint64_t i = 0; i < record_count; i++) {
|
||||
auto& result = results[i];
|
||||
const auto& record = record_array_[i];
|
||||
|
||||
thrift::TopKQueryResult thrift_topk_result;
|
||||
for(auto id : result) {
|
||||
thrift::QueryResult thrift_result;
|
||||
thrift_result.__set_id(id);
|
||||
|
||||
//need get attributes?
|
||||
if(record.selected_column_array.empty()) {
|
||||
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string nid = std::to_string(id);
|
||||
std::string attrib_str;
|
||||
IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_);
|
||||
|
||||
AttribMap attrib_map;
|
||||
AttributeSerializer::Decode(attrib_str, attrib_map);
|
||||
|
||||
for(auto& attri : record.selected_column_array) {
|
||||
thrift_result.column_map[attri] = attrib_map[attri];
|
||||
}
|
||||
|
||||
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
|
||||
}
|
||||
|
||||
result_array_.emplace_back(thrift_topk_result);
|
||||
}
|
||||
rc.Record("construct result");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
|
@ -65,6 +65,20 @@ private:
|
||||
std::string table_name_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class ShowTablesTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(std::vector<std::string>& tables);
|
||||
|
||||
protected:
|
||||
ShowTablesTask(std::vector<std::string>& tables);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string>& tables_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddVectorTask : public BaseTask {
|
||||
public:
|
||||
|
@ -108,6 +108,19 @@ bool RocksIdMapper::IsGroupExist(const std::string& group) const {
|
||||
return IsGroupExistInternal(group);
|
||||
}
|
||||
|
||||
ServerError RocksIdMapper::AllGroups(std::vector<std::string>& groups) const {
|
||||
groups.clear();
|
||||
|
||||
std::lock_guard<std::mutex> lck(db_mutex_);
|
||||
for(auto& pair : column_handles_) {
|
||||
if(pair.first == ROCKSDB_DEFAULT_GROUP) {
|
||||
continue;
|
||||
}
|
||||
groups.push_back(pair.first);
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
ServerError RocksIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
|
||||
std::lock_guard<std::mutex> lck(db_mutex_);
|
||||
|
@ -26,6 +26,7 @@ public:
|
||||
|
||||
ServerError AddGroup(const std::string& group) override;
|
||||
bool IsGroupExist(const std::string& group) const override;
|
||||
ServerError AllGroups(std::vector<std::string>& groups) const override;
|
||||
|
||||
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
|
||||
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
|
||||
|
@ -50,6 +50,16 @@ bool SimpleIdMapper::IsGroupExist(const std::string& group) const {
|
||||
return id_groups_.count(group) > 0;
|
||||
}
|
||||
|
||||
ServerError SimpleIdMapper::AllGroups(std::vector<std::string>& groups) const {
|
||||
groups.clear();
|
||||
|
||||
for(auto& pair : id_groups_) {
|
||||
groups.push_back(pair.first);
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
//not thread-safe
|
||||
ServerError SimpleIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
|
||||
ID_MAPPING& mapping = id_groups_[group];
|
||||
|
@ -27,6 +27,7 @@ public:
|
||||
|
||||
virtual ServerError AddGroup(const std::string& group) = 0;
|
||||
virtual bool IsGroupExist(const std::string& group) const = 0;
|
||||
virtual ServerError AllGroups(std::vector<std::string>& groups) const = 0;
|
||||
|
||||
virtual ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") = 0;
|
||||
virtual ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") = 0;
|
||||
@ -46,6 +47,7 @@ public:
|
||||
|
||||
ServerError AddGroup(const std::string& group) override;
|
||||
bool IsGroupExist(const std::string& group) const override;
|
||||
ServerError AllGroups(std::vector<std::string>& groups) const override;
|
||||
|
||||
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
|
||||
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
|
||||
|
Loading…
Reference in New Issue
Block a user