mirror of https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 19:08:30 +08:00
Merge branch 'branch-0.3.1' into 'branch-0.3.1'

MS-266 Improve topk reduce time by using multi-threads

See merge request megasearch/milvus!271
Former-commit-id: c5b0ce9ebbd2a60fe7797b38996a70293c577789

Commit 4a6b1779f6
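In summary, the merge adds a parallel_reduce option to the db_config section (with a matching CONFIG_DB_PARALLEL_REDUCE key in ServerConfig.h), introduces NeedParallelReduce and ParallelReduce helpers in SearchTask.cpp, and rewrites SearchTask::ClusterResult and SearchTask::TopkResult to run their per-query loops as chunked worker lambdas across multiple threads. Alongside that, several SearchTask log statements move from the server logger to the engine logger, the recurring DBImpl background-task logs drop to TRACE, CacheMgr loses a stray error log, and the search unit tests gain an ascending/descending flag plus two parallel-path test cases.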
@@ -8,6 +8,8 @@ db_config:
 db_path: @MILVUS_DB_PATH@ # milvus data storage path
 db_slave_path: # secondry data storage path, split by semicolon
 
+parallel_reduce: true # use multi-threads to reduce topk result
+
 # URI format: dialect://username:password@host:port/database
 # All parts except dialect are optional, but you MUST include the delimiters
 # Currently dialect supports mysql or sqlite
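The flag defaults to on in both places: the template ships parallel_reduce: true, and NeedParallelReduce (added below in SearchTask.cpp) falls back to true when the key is absent from the config.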
cpp/src/cache/CacheMgr.cpp (vendored, 1 change)

@@ -49,7 +49,6 @@ DataObjPtr CacheMgr::GetItem(const std::string& key) {
 
 engine::Index_ptr CacheMgr::GetIndex(const std::string& key) {
     DataObjPtr obj = GetItem(key);
     if(obj != nullptr) {
-        SERVER_LOG_ERROR << "Can't get object from key: " << key;
         return obj->data();
     }
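Worth noting: the removed SERVER_LOG_ERROR sat inside the if(obj != nullptr) branch, i.e. it reported "Can't get object" on every cache hit, so this is a bug fix rather than a log-level tweak.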
cpp/src/db/DBImpl.cpp

@@ -89,7 +89,7 @@ DBImpl::DBImpl(const Options& options)
     meta_ptr_ = DBMetaImplFactory::Build(options.meta, options.mode);
     mem_mgr_ = MemManagerFactory::Build(meta_ptr_, options_);
     if (options.mode != Options::MODE::READ_ONLY) {
-        ENGINE_LOG_INFO << "StartTimerTasks";
+        ENGINE_LOG_TRACE << "StartTimerTasks";
         StartTimerTasks();
     }
 
@@ -297,7 +297,7 @@ void DBImpl::StartMetricTask() {
         return;
     }
 
-    ENGINE_LOG_DEBUG << "Start metric task";
+    ENGINE_LOG_TRACE << "Start metric task";
 
     server::Metrics::GetInstance().KeepingAliveCounterIncrement(METRIC_ACTION_INTERVAL);
     int64_t cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
@@ -312,7 +312,7 @@ void DBImpl::StartMetricTask() {
     server::Metrics::GetInstance().GPUMemoryUsageGaugeSet();
     server::Metrics::GetInstance().OctetsSet();
 
-    ENGINE_LOG_DEBUG << "Metric task finished";
+    ENGINE_LOG_TRACE << "Metric task finished";
 }
 
 void DBImpl::StartCompactionTask() {
@@ -322,8 +322,6 @@ void DBImpl::StartCompactionTask() {
         return;
     }
 
-    ENGINE_LOG_DEBUG << "Serialize insert cache";
-
     //serialize memory data
     std::set<std::string> temp_table_ids;
     mem_mgr_->Serialize(temp_table_ids);
@@ -331,7 +329,9 @@
         compact_table_ids_.insert(id);
     }
 
-    ENGINE_LOG_DEBUG << "Insert cache serialized";
+    if(!temp_table_ids.empty()) {
+        SERVER_LOG_DEBUG << "Insert cache serialized";
+    }
 
     //compactiong has been finished?
     if(!compact_thread_results_.empty()) {
@@ -433,7 +433,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
 }
 
 void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
-    ENGINE_LOG_DEBUG << " Background compaction thread start";
+    ENGINE_LOG_TRACE << " Background compaction thread start";
 
     Status status;
     for (auto& table_id : table_ids) {
@@ -452,7 +452,7 @@ void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
     }
     meta_ptr_->CleanUpFilesWithTTL(ttl);
 
-    ENGINE_LOG_DEBUG << " Background compaction thread exit";
+    ENGINE_LOG_TRACE << " Background compaction thread exit";
 }
 
 void DBImpl::StartBuildIndexTask(bool force) {
@@ -581,7 +581,7 @@ Status DBImpl::BuildIndexByTable(const std::string& table_id) {
 }
 
 void DBImpl::BackgroundBuildIndex() {
-    ENGINE_LOG_DEBUG << " Background build index thread start";
+    ENGINE_LOG_TRACE << " Background build index thread start";
 
     std::unique_lock<std::mutex> lock(build_index_mutex_);
     meta::TableFilesSchema to_index_files;
@@ -599,7 +599,7 @@ void DBImpl::BackgroundBuildIndex() {
         }
     }
 
-    ENGINE_LOG_DEBUG << " Background build index thread exit";
+    ENGINE_LOG_TRACE << " Background build index thread exit";
 }
 
 Status DBImpl::DropAll() {
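The DBImpl changes are logging hygiene only: the recurring background-task messages (timer start, metric task, compaction and build-index thread start/exit) drop from INFO/DEBUG to TRACE, and "Insert cache serialized" is now emitted only when Serialize actually returned table ids.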
cpp/src/db/scheduler/task/SearchTask.cpp

@@ -5,14 +5,60 @@
 ******************************************************************************/
 #include "SearchTask.h"
 #include "metrics/Metrics.h"
-#include "utils/Log.h"
+#include "db/Log.h"
 #include "utils/TimeRecorder.h"
 
+#include <thread>
+
 namespace zilliz {
 namespace milvus {
 namespace engine {
 
 namespace {
 
+static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
+static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
+
+bool NeedParallelReduce(uint64_t nq, uint64_t topk) {
+    server::ServerConfig &config = server::ServerConfig::GetInstance();
+    server::ConfigNode& db_config = config.GetConfig(server::CONFIG_DB);
+    bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, true);
+    if(!need_parallel) {
+        return false;
+    }
+
+    return nq*topk >= PARALLEL_REDUCE_THRESHOLD;
+}
+
+void ParallelReduce(std::function<void(size_t, size_t)>& reduce_function, size_t max_index) {
+    size_t reduce_batch = PARALLEL_REDUCE_BATCH;
+
+    auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
+    if(thread_count > 0) {
+        reduce_batch = max_index/thread_count + 1;
+    }
+    ENGINE_LOG_DEBUG << "use " << thread_count <<
+        " thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
+
+    std::vector<std::shared_ptr<std::thread> > thread_array;
+    size_t from_index = 0;
+    while(from_index < max_index) {
+        size_t to_index = from_index + reduce_batch;
+        if(to_index > max_index) {
+            to_index = max_index;
+        }
+
+        auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
+        thread_array.push_back(reduce_thread);
+
+        from_index = to_index;
+    }
+
+    for(auto& thread_ptr : thread_array) {
+        thread_ptr->join();
+    }
+}
+
 void CollectDurationMetrics(int index_type, double total_time) {
     switch(index_type) {
         case meta::TableFileSchema::RAW: {
@@ -32,7 +78,7 @@ void CollectDurationMetrics(int index_type, double total_time) {
 
 std::string GetMetricType() {
     server::ServerConfig &config = server::ServerConfig::GetInstance();
-    server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
+    server::ConfigNode& engine_config = config.GetConfig(server::CONFIG_ENGINE);
     return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
 }
 
@@ -51,7 +97,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
         return nullptr;
     }
 
-    SERVER_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
+    ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
             << search_contexts_.size() << " tasks";
 
     server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
@@ -79,6 +125,9 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
             auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
             SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
 
+            span = rc.RecordSection("cluster result for context:" + context->Identity());
+            context->AccumReduceCost(span);
+
             //step 4: pick up topk result
             SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
 
@@ -86,7 +135,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
             context->AccumReduceCost(span);
 
         } catch (std::exception& ex) {
-            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
+            ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
             context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
             continue;
         }
@@ -112,23 +161,32 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
     if(output_ids.size() < nq*topk || output_distence.size() < nq*topk) {
         std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
                 " distance array size: " + std::to_string(output_distence.size());
-        SERVER_LOG_ERROR << msg;
+        ENGINE_LOG_ERROR << msg;
         return Status::Error(msg);
     }
 
     result_set.clear();
-    result_set.reserve(nq);
-    for (auto i = 0; i < nq; i++) {
-        SearchContext::Id2DistanceMap id_distance;
-        id_distance.reserve(topk);
-        for (auto k = 0; k < topk; k++) {
-            uint64_t index = i * topk + k;
-            if(output_ids[index] < 0) {
-                continue;
-            }
-            id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
-        }
-        result_set.emplace_back(id_distance);
+    result_set.resize(nq);
+
+    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
+        for (auto i = from_index; i < to_index; i++) {
+            SearchContext::Id2DistanceMap id_distance;
+            id_distance.reserve(topk);
+            for (auto k = 0; k < topk; k++) {
+                uint64_t index = i * topk + k;
+                if(output_ids[index] < 0) {
+                    continue;
+                }
+                id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
+            }
+            result_set[i] = id_distance;
+        }
+    };
+
+    if(NeedParallelReduce(nq, topk)) {
+        ParallelReduce(reduce_worker, nq);
+    } else {
+        reduce_worker(0, nq);
     }
 
     return Status::OK();
@@ -140,7 +198,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                                bool ascending) {
     //Note: the score_src and score_target are already arranged by score in ascending order
     if(distance_src.empty()) {
-        SERVER_LOG_WARNING << "Empty distance source array";
+        ENGINE_LOG_WARNING << "Empty distance source array";
        return Status::OK();
     }
 
@@ -218,14 +276,22 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
 
     if (result_src.size() != result_target.size()) {
         std::string msg = "Invalid result set size";
-        SERVER_LOG_ERROR << msg;
+        ENGINE_LOG_ERROR << msg;
         return Status::Error(msg);
     }
 
-    for (size_t i = 0; i < result_src.size(); i++) {
-        SearchContext::Id2DistanceMap &score_src = result_src[i];
-        SearchContext::Id2DistanceMap &score_target = result_target[i];
-        SearchTask::MergeResult(score_src, score_target, topk, ascending);
+    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
+        for (size_t i = from_index; i < to_index; i++) {
+            SearchContext::Id2DistanceMap &score_src = result_src[i];
+            SearchContext::Id2DistanceMap &score_target = result_target[i];
+            SearchTask::MergeResult(score_src, score_target, topk, ascending);
+        }
+    };
+
+    if(NeedParallelReduce(result_src.size(), topk)) {
+        ParallelReduce(ReduceWorker, result_src.size());
+    } else {
+        ReduceWorker(0, result_src.size());
     }
 
     return Status::OK();
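To see the pattern in isolation: ParallelReduce splits [0, max_index) into batches sized so that roughly hardware_concurrency() - 1 threads each get one batch, runs the worker on each batch in its own std::thread, and joins them all before returning. Below is a minimal, self-contained sketch of that same fan-out using only the standard library; the names ParallelFor and worker are illustrative, not Milvus code. One deviation is deliberate: std::thread::hardware_concurrency() may return 0, in which case the patch's unsigned hardware_concurrency() - 1 wraps around to a huge thread count and the batch size collapses to 1; the sketch guards against that.

// Minimal sketch (illustrative names, standard library only) of the
// chunked parallel reduce introduced by this commit.
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

void ParallelFor(std::function<void(size_t, size_t)>& worker, size_t max_index) {
    size_t batch = 1000; // fallback batch size, mirroring PARALLEL_REDUCE_BATCH

    // Leave one core for other work, as the patch does, but guard against
    // hardware_concurrency() returning 0 (its unsigned result would underflow).
    size_t cores = std::thread::hardware_concurrency();
    if (cores > 1) {
        batch = max_index / (cores - 1) + 1;
    }

    std::vector<std::shared_ptr<std::thread>> threads;
    for (size_t from = 0; from < max_index; from += batch) {
        size_t to = std::min(from + batch, max_index);
        threads.push_back(std::make_shared<std::thread>(worker, from, to));
    }
    for (auto& t : threads) {
        t->join(); // wait for every chunk before returning
    }
}

int main() {
    const size_t n = 100000;
    std::vector<long> squares(n);
    // Like reduce_worker in ClusterResult, each thread writes only its own
    // [from, to) slice of the output, so no synchronization is needed.
    std::function<void(size_t, size_t)> worker = [&](size_t from, size_t to) {
        for (size_t i = from; i < to; i++) {
            squares[i] = static_cast<long>(i) * static_cast<long>(i);
        }
    };
    ParallelFor(worker, n);
    std::cout << "squares[9] = " << squares[9] << std::endl; // prints 81
    return 0;
}

For scale: with PARALLEL_REDUCE_THRESHOLD = 10000, NeedParallelReduce takes the parallel path only when the config flag is on and nq * topk >= 10000, so nq = 100 with topk = 100 just qualifies, while the unit-test defaults (NQ = 15, TOP_K = 64, i.e. 960 result slots) stay single-threaded.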
cpp/src/server/ServerConfig.h

@@ -29,6 +29,7 @@ static const std::string CONFIG_DB_INDEX_TRIGGER_SIZE = "index_building_threshold";
 static const std::string CONFIG_DB_ARCHIVE_DISK = "archive_disk_threshold";
 static const std::string CONFIG_DB_ARCHIVE_DAYS = "archive_days_threshold";
 static const std::string CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size";
+static const std::string CONFIG_DB_PARALLEL_REDUCE = "parallel_reduce";
 
 static const std::string CONFIG_LOG = "log_config";
 
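This constant is the lookup key for the parallel_reduce entry added to db_config above; NeedParallelReduce reads it via db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, true).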
@@ -6,6 +6,8 @@
 #include <gtest/gtest.h>
 
 #include "db/scheduler/task/SearchTask.h"
+#include "utils/TimeRecorder.h"
+
 #include <cmath>
 #include <vector>
 
@@ -17,27 +19,33 @@ static constexpr uint64_t NQ = 15;
 static constexpr uint64_t TOP_K = 64;
 
 void BuildResult(uint64_t nq,
-                 uint64_t top_k,
+                 uint64_t topk,
+                 bool ascending,
                  std::vector<long> &output_ids,
                  std::vector<float> &output_distence) {
     output_ids.clear();
-    output_ids.resize(nq*top_k);
+    output_ids.resize(nq*topk);
     output_distence.clear();
-    output_distence.resize(nq*top_k);
+    output_distence.resize(nq*topk);
 
     for(uint64_t i = 0; i < nq; i++) {
-        for(uint64_t j = 0; j < top_k; j++) {
-            output_ids[i * top_k + j] = (long)(drand48()*100000);
-            output_distence[i * top_k + j] = j + drand48();
+        for(uint64_t j = 0; j < topk; j++) {
+            output_ids[i * topk + j] = (long)(drand48()*100000);
+            output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
        }
    }
 }
 
 void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
                  const engine::SearchContext::Id2DistanceMap& src_2,
-                 const engine::SearchContext::Id2DistanceMap& target) {
+                 const engine::SearchContext::Id2DistanceMap& target,
+                 bool ascending) {
     for(uint64_t i = 0; i < target.size() - 1; i++) {
-        ASSERT_LE(target[i].second, target[i + 1].second);
+        if(ascending) {
+            ASSERT_LE(target[i].second, target[i + 1].second);
+        } else {
+            ASSERT_GE(target[i].second, target[i + 1].second);
+        }
    }
 
    using ID2DistMap = std::map<long, float>;
@@ -57,9 +65,52 @@ void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
    }
 }
 
+void CheckCluster(const std::vector<long>& target_ids,
+                  const std::vector<float>& target_distence,
+                  const engine::SearchContext::ResultSet& src_result,
+                  int64_t nq,
+                  int64_t topk) {
+    ASSERT_EQ(src_result.size(), nq);
+    for(int64_t i = 0; i < nq; i++) {
+        auto& res = src_result[i];
+        ASSERT_EQ(res.size(), topk);
+
+        if(res.empty()) {
+            continue;
+        }
+
+        ASSERT_EQ(res[0].first, target_ids[i*topk]);
+        ASSERT_EQ(res[topk - 1].first, target_ids[i*topk + topk - 1]);
+    }
+}
+
+void CheckTopkResult(const engine::SearchContext::ResultSet& src_result,
+                     bool ascending,
+                     int64_t nq,
+                     int64_t topk) {
+    ASSERT_EQ(src_result.size(), nq);
+    for(int64_t i = 0; i < nq; i++) {
+        auto& res = src_result[i];
+        ASSERT_EQ(res.size(), topk);
+
+        if(res.empty()) {
+            continue;
+        }
+
+        for(int64_t k = 0; k < topk - 1; k++) {
+            if(ascending) {
+                ASSERT_LE(res[k].second, res[k + 1].second);
+            } else {
+                ASSERT_GE(res[k].second, res[k + 1].second);
+            }
+        }
+    }
+}
+
 }
 
 TEST(DBSearchTest, TOPK_TEST) {
+    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;
@@ -67,19 +118,19 @@ TEST(DBSearchTest, TOPK_TEST) {
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());
 
-    BuildResult(NQ, TOP_K, target_ids, target_distence);
+    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);
 
    engine::SearchContext::ResultSet target_result;
-    status = engine::SearchTask::TopkResult(target_result, TOP_K, true, target_result);
+    status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
 
-    status = engine::SearchTask::TopkResult(target_result, TOP_K, true, src_result);
+    status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
    ASSERT_FALSE(status.ok());
 
-    status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
+    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);
@@ -87,21 +138,21 @@ TEST(DBSearchTest, TOPK_TEST) {
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
-    BuildResult(NQ, wrong_topk, src_ids, src_distence);
+    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
 
    status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());
 
-    status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
+    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
 
    wrong_topk = TOP_K + 10;
-    BuildResult(NQ, wrong_topk, src_ids, src_distence);
+    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
 
-    status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
+    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
@@ -109,6 +160,7 @@ TEST(DBSearchTest, TOPK_TEST) {
 }
 
 TEST(DBSearchTest, MERGE_TEST) {
+    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    std::vector<long> src_ids;
@@ -116,8 +168,8 @@ TEST(DBSearchTest, MERGE_TEST) {
    engine::SearchContext::ResultSet src_result, target_result;
 
    uint64_t src_count = 5, target_count = 8;
-    BuildResult(1, src_count, src_ids, src_distence);
-    BuildResult(1, target_count, target_ids, target_distence);
+    BuildResult(1, src_count, ascending, src_ids, src_distence);
+    BuildResult(1, target_count, ascending, target_ids, target_distence);
    auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
    ASSERT_TRUE(status.ok());
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
@@ -126,37 +178,107 @@ TEST(DBSearchTest, MERGE_TEST) {
    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target = target_result[0];
-        status = engine::SearchTask::MergeResult(src, target, 10, true);
+        status = engine::SearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), 10);
-        CheckResult(src_result[0], target_result[0], target);
+        CheckResult(src_result[0], target_result[0], target, ascending);
    }
 
    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target;
-        status = engine::SearchTask::MergeResult(src, target, 10, true);
+        status = engine::SearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count);
        ASSERT_TRUE(src.empty());
-        CheckResult(src_result[0], target_result[0], target);
+        CheckResult(src_result[0], target_result[0], target, ascending);
    }
 
    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target = target_result[0];
-        status = engine::SearchTask::MergeResult(src, target, 30, true);
+        status = engine::SearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target);
+        CheckResult(src_result[0], target_result[0], target, ascending);
    }
 
    {
        engine::SearchContext::Id2DistanceMap target = src_result[0];
        engine::SearchContext::Id2DistanceMap src = target_result[0];
-        status = engine::SearchTask::MergeResult(src, target, 30, true);
+        status = engine::SearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target);
+        CheckResult(src_result[0], target_result[0], target, ascending);
    }
 }
 
+TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
+    bool ascending = true;
+    std::vector<long> target_ids;
+    std::vector<float> target_distence;
+    engine::SearchContext::ResultSet src_result;
+
+    auto DoCluster = [&](int64_t nq, int64_t topk) {
+        server::TimeRecorder rc("DoCluster");
+        src_result.clear();
+        BuildResult(nq, topk, ascending, target_ids, target_distence);
+        rc.RecordSection("build id/dietance map");
+
+        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
+        ASSERT_TRUE(status.ok());
+        ASSERT_EQ(src_result.size(), nq);
+
+        rc.RecordSection("cluster result");
+
+        CheckCluster(target_ids, target_distence, src_result, nq, topk);
+        rc.RecordSection("check result");
+    };
+
+    DoCluster(10000, 1000);
+    DoCluster(333, 999);
+    DoCluster(1, 1000);
+    DoCluster(1, 1);
+    DoCluster(7, 0);
+    DoCluster(9999, 1);
+    DoCluster(10001, 1);
+    DoCluster(58273, 1234);
+}
+
+TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
+    std::vector<long> target_ids;
+    std::vector<float> target_distence;
+    engine::SearchContext::ResultSet src_result;
+
+    std::vector<long> insufficient_ids;
+    std::vector<float> insufficient_distence;
+    engine::SearchContext::ResultSet insufficient_result;
+
+    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
+        src_result.clear();
+        insufficient_result.clear();
+
+        server::TimeRecorder rc("DoCluster");
+
+        BuildResult(nq, topk, ascending, target_ids, target_distence);
+        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
+        rc.RecordSection("cluster result");
+
+        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
+        status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk, insufficient_result);
+        rc.RecordSection("cluster result");
+
+        engine::SearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
+        ASSERT_TRUE(status.ok());
+        rc.RecordSection("topk");
+
+        CheckTopkResult(src_result, ascending, nq, topk);
+        rc.RecordSection("check result");
+    };
+
+    DoTopk(5, 10, 4, false);
+    DoTopk(20005, 998, 123, true);
+    DoTopk(9987, 12, 10, false);
+    DoTopk(77777, 1000, 1, false);
+    DoTopk(5432, 8899, 8899, true);
+}
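A note on the ascending flag threaded through these tests: BuildResult now fabricates distances that grow with rank when ascending is true and shrink when it is false, and CheckResult/CheckTopkResult assert the corresponding order (ASSERT_LE vs ASSERT_GE). In SearchTask::Execute the flag is supplied as metric_l2, ascending for L2 distance and presumably descending for similarity-style metrics where a larger score is better.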