Merge remote-tracking branch 'upstream/branch-0.3.1' into branch-0.3.1

Former-commit-id: 4f3226d234fb9cbbab86cb7e85d65e97c5773e2c
zhiru committed 2019-07-24 20:34:33 +08:00
commit cfffcf9e9a
29 changed files with 540 additions and 219 deletions

View File

@@ -18,3 +18,4 @@ Please mark all change in change log and use the ticket from JIRA.
 - MS-161 - Add CI / CD Module to Milvus Project
 - MS-202 - Add Milvus Jenkins project email notification
 - MS-215 - Add Milvus cluster CI/CD groovy file
+- MS-277 - Update CUDA Version to V10.1

View File

@@ -35,7 +35,7 @@ pipeline {
             defaultContainer 'jnlp'
             containerTemplate {
                 name 'milvus-build-env'
-                image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
+                image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
                 ttyEnabled true
                 command 'cat'
             }

View File

@@ -35,7 +35,7 @@ pipeline {
             defaultContainer 'jnlp'
             containerTemplate {
                 name 'milvus-build-env'
-                image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
+                image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
                 ttyEnabled true
                 command 'cat'
             }

View File

@@ -35,7 +35,7 @@ pipeline {
             defaultContainer 'jnlp'
             containerTemplate {
                 name 'milvus-build-env'
-                image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
+                image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
                 ttyEnabled true
                 command 'cat'
             }

View File

@@ -752,10 +752,7 @@ macro(build_faiss)
     if(${MILVUS_WITH_FAISS_GPU_VERSION} STREQUAL "ON")
         set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
                 "--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}"
-                "--with-cuda-arch=\"-gencode=arch=compute_35,code=sm_35\""
-                "--with-cuda-arch=\"-gencode=arch=compute_52,code=sm_52\""
-                "--with-cuda-arch=\"-gencode=arch=compute_60,code=sm_60\""
-                "--with-cuda-arch=\"-gencode=arch=compute_61,code=sm_61\""
+                "--with-cuda-arch=-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_75,code=sm_75"
                 )
     else()
         set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} --without-cuda)
@@ -769,7 +766,7 @@ macro(build_faiss)
         "./configure"
         ${FAISS_CONFIGURE_ARGS}
     BUILD_COMMAND
-        ${MAKE} ${MAKE_BUILD_ARGS}
+        ${MAKE} ${MAKE_BUILD_ARGS} VERBOSE=1
     BUILD_IN_SOURCE
         1
     INSTALL_COMMAND
@@ -1676,14 +1673,18 @@ macro(build_gperftools)
         BUILD_BYPRODUCTS
             ${GPERFTOOLS_STATIC_LIB})

+    ExternalProject_Add_StepDependencies(gperftools_ep build libunwind_ep)
+
     file(MAKE_DIRECTORY "${GPERFTOOLS_INCLUDE_DIR}")
-    add_library(gperftools SHARED IMPORTED)
+    add_library(gperftools STATIC IMPORTED)
     set_target_properties(gperftools
                           PROPERTIES IMPORTED_LOCATION "${GPERFTOOLS_STATIC_LIB}"
-                          INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
+                          INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}"
+                          INTERFACE_LINK_LIBRARIES libunwind)

     add_dependencies(gperftools gperftools_ep)
+    add_dependencies(gperftools libunwind_ep)
 endmacro()

 if(MILVUS_WITH_GPERFTOOLS)
@@ -1692,4 +1693,5 @@ if(MILVUS_WITH_GPERFTOOLS)
     # TODO: Don't use global includes but rather target_include_directories
     get_target_property(GPERFTOOLS_INCLUDE_DIR gperftools INTERFACE_INCLUDE_DIRECTORIES)
     include_directories(SYSTEM ${GPERFTOOLS_INCLUDE_DIR})
+    link_directories(SYSTEM ${GPERFTOOLS_PREFIX}/lib)
 endif()

View File

@@ -8,6 +8,8 @@ db_config:
     db_path: @MILVUS_DB_PATH@            # milvus data storage path
     db_slave_path:                       # secondry data storage path, split by semicolon

+    parallel_reduce: false               # use multi-threads to reduce topk result
+
     # URI format: dialect://username:password@host:port/database
     # All parts except dialect are optional, but you MUST include the delimiters
     # Currently dialect supports mysql or sqlite

View File

@@ -63,10 +63,6 @@ include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
 include_directories(thrift/gen-cpp)
 include_directories(/usr/include/mysql)

-if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
-    SET(PROFILER_LIB profiler)
-endif()
-
 set(third_party_libs
         easyloggingpp
         sqlite
@@ -85,7 +81,6 @@ set(third_party_libs
         zlib
        zstd
         mysqlpp
-        ${PROFILER_LIB}
         ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
         cudart
         )
@@ -103,6 +98,12 @@ else()
             openblas)
 endif()

+if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
+    set(third_party_libs ${third_party_libs}
+            gperftools
+            libunwind)
+endif()
+
 if (GPU_VERSION STREQUAL "ON")
     link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
     set(engine_libs
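
For context on the MILVUS_ENABLE_PROFILING switch above: once gperftools is linked, a CPU profile is typically captured by bracketing a hot section with the profiler API. A hedged, self-contained sketch, not Milvus code; it assumes only <gperftools/profiler.h> from the library this hunk links:

// profiling_sketch.cpp - build with: g++ profiling_sketch.cpp -lprofiler
#include <gperftools/profiler.h>  // ProfilerStart/ProfilerStop from gperftools

#include <cmath>

int main() {
    ProfilerStart("/tmp/cpu.prof");          // begin writing samples to this file
    double acc = 0.0;
    for (int i = 1; i < 10000000; ++i) {     // stand-in for a real workload
        acc += std::sqrt(static_cast<double>(i));
    }
    ProfilerStop();                          // flush and close the profile
    return acc > 0.0 ? 0 : 1;
}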

View File

@@ -89,7 +89,7 @@ void Cache::erase(const std::string& key) {
     const DataObjPtr& data_ptr = obj_ptr->data_;

     usage_ -= data_ptr->size();
-    SERVER_LOG_DEBUG << "Erase " << key << " from cache";
+    SERVER_LOG_DEBUG << "Erase " << key << " size: " << data_ptr->size();
     lru_.erase(key);
 }

View File

@@ -4,6 +4,7 @@
 // Proprietary and confidential.
 ////////////////////////////////////////////////////////////////////////////////

+#include "utils/Log.h"
 #include "CacheMgr.h"
 #include "metrics/Metrics.h"
@@ -20,6 +21,7 @@ CacheMgr::~CacheMgr() {
 uint64_t CacheMgr::ItemCount() const {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return 0;
     }
@@ -28,6 +30,7 @@ uint64_t CacheMgr::ItemCount() const {
 bool CacheMgr::ItemExists(const std::string& key) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return false;
     }
@@ -36,6 +39,7 @@ bool CacheMgr::ItemExists(const std::string& key) {
 DataObjPtr CacheMgr::GetItem(const std::string& key) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return nullptr;
     }
     server::Metrics::GetInstance().CacheAccessTotalIncrement();
@@ -53,6 +57,7 @@ engine::Index_ptr CacheMgr::GetIndex(const std::string& key) {
 void CacheMgr::InsertItem(const std::string& key, const DataObjPtr& data) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
@@ -62,6 +67,7 @@ void CacheMgr::InsertItem(const std::string& key, const DataObjPtr& data) {
 void CacheMgr::InsertItem(const std::string& key, const engine::Index_ptr& index) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
@@ -72,6 +78,7 @@ void CacheMgr::InsertItem(const std::string& key, const engine::Index_ptr& index
 void CacheMgr::EraseItem(const std::string& key) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
@@ -81,6 +88,7 @@ void CacheMgr::EraseItem(const std::string& key) {
 void CacheMgr::PrintInfo() {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
@@ -89,6 +97,7 @@ void CacheMgr::PrintInfo() {
 void CacheMgr::ClearCache() {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
@@ -97,6 +106,7 @@ void CacheMgr::ClearCache() {
 int64_t CacheMgr::CacheUsage() const {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return 0;
     }
@@ -105,6 +115,7 @@ int64_t CacheMgr::CacheUsage() const {
 int64_t CacheMgr::CacheCapacity() const {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return 0;
     }
@@ -113,6 +124,7 @@ int64_t CacheMgr::CacheCapacity() const {
 void CacheMgr::SetCapacity(int64_t capacity) {
     if(cache_ == nullptr) {
+        SERVER_LOG_ERROR << "Cache doesn't exist";
         return;
     }
     cache_->set_capacity(capacity);

View File

@@ -12,10 +12,14 @@ namespace zilliz {
 namespace milvus {
 namespace cache {

+namespace {
+    constexpr int64_t unit = 1024 * 1024 * 1024;
+}
+
 CpuCacheMgr::CpuCacheMgr() {
     server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE);
     int64_t cap = config.GetInt64Value(server::CONFIG_CPU_CACHE_CAPACITY, 16);
-    cap *= 1024*1024*1024;
+    cap *= unit;
     cache_ = std::make_shared<Cache>(cap, 1UL<<32);

     double free_percent = config.GetDoubleValue(server::CACHE_FREE_PERCENT, 0.85);
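
A hedged aside on why the hoisted constant is typed int64_t: the capacity multiplication must happen in 64 bits, since a 16 GiB byte count does not fit in a 32-bit int. A tiny standalone sketch, not Milvus code:

#include <cstdint>
#include <iostream>

namespace {
constexpr int64_t unit = 1024 * 1024 * 1024;  // 1 GiB; this literal alone still fits in int
}

int main() {
    int64_t cap = 16;   // e.g. cpu_cache_capacity (in GB) from server_config.yaml
    cap *= unit;        // 64-bit multiply: 17179869184, no overflow
    std::cout << cap << '\n';
    // By contrast, an all-int expression such as 16 * 1024 * 1024 * 1024
    // would overflow 32-bit int before any widening took place.
}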

View File

@@ -11,10 +11,14 @@ namespace zilliz {
 namespace milvus {
 namespace cache {

+namespace {
+    constexpr int64_t unit = 1024 * 1024 * 1024;
+}
+
 GpuCacheMgr::GpuCacheMgr() {
     server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE);
     int64_t cap = config.GetInt64Value(server::CONFIG_GPU_CACHE_CAPACITY, 1);
-    cap *= 1024*1024*1024;
+    cap *= unit;
     cache_ = std::make_shared<Cache>(cap, 1UL<<32);
 }

View File

@@ -94,7 +94,7 @@ double
 ConfigNode::GetDoubleValue(const std::string &param_key, double default_val) const {
     std::string val = GetValue(param_key);
     if (!val.empty()) {
-        return std::strtold(val.c_str(), nullptr);
+        return std::strtod(val.c_str(), nullptr);
     } else {
         return default_val;
     }
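
The switch from std::strtold to std::strtod avoids an implicit long double to double narrowing on return. strtod can also report where parsing stopped; a small hedged sketch of the fuller error-checking idiom (a hypothetical helper, not part of this commit):

#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical helper: parse a double, falling back to a default on failure.
double ParseDoubleOr(const std::string& val, double default_val) {
    char* end = nullptr;
    double parsed = std::strtod(val.c_str(), &end);
    if (end == val.c_str()) {   // no characters consumed: not a number
        return default_val;
    }
    return parsed;
}

int main() {
    std::cout << ParseDoubleOr("0.85", 0.5) << '\n';  // prints 0.85
    std::cout << ParseDoubleOr("oops", 0.5) << '\n';  // prints 0.5
}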

View File

@@ -9,14 +9,14 @@ namespace zilliz {
 namespace milvus {
 namespace engine {

-const size_t K = 1024UL;
-const size_t M = K * K;
-const size_t G = K * M;
-const size_t T = K * G;
+constexpr size_t K = 1024UL;
+constexpr size_t M = K * K;
+constexpr size_t G = K * M;
+constexpr size_t T = K * G;

-const size_t MAX_TABLE_FILE_MEM = 128 * M;
+constexpr size_t MAX_TABLE_FILE_MEM = 128 * M;

-const int VECTOR_TYPE_SIZE = sizeof(float);
+constexpr int VECTOR_TYPE_SIZE = sizeof(float);

 } // namespace engine
 } // namespace milvus

View File

@@ -12,11 +12,10 @@ namespace zilliz {
 namespace milvus {
 namespace engine {

-DB::~DB() {}
+DB::~DB() = default;

 void DB::Open(const Options& options, DB** dbptr) {
     *dbptr = DBFactory::Build(options);
-    return;
 }

 } // namespace engine

View File

@@ -52,7 +52,7 @@ public:
     DB(const DB&) = delete;
     DB& operator=(const DB&) = delete;

-    virtual ~DB();
+    virtual ~DB() = 0;
 }; // DB

 } // namespace engine
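
A note on the pattern in this hunk: a destructor can be declared pure virtual to keep the class abstract, but it still needs an out-of-line definition, which DB.cpp above supplies with DB::~DB() = default;, because every derived destructor calls it. A minimal self-contained sketch of the idiom:

#include <iostream>

struct Base {
    virtual ~Base() = 0;   // pure virtual: Base itself cannot be instantiated
};
Base::~Base() = default;   // definition still required; derived dtors invoke it

struct Derived : Base {
    ~Derived() override { std::cout << "~Derived\n"; }
};

int main() {
    Base* b = new Derived();
    delete b;              // prints "~Derived", then Base::~Base runs
}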

View File

@@ -89,7 +89,7 @@ DBImpl::DBImpl(const Options& options)
     meta_ptr_ = DBMetaImplFactory::Build(options.meta, options.mode);
     mem_mgr_ = MemManagerFactory::Build(meta_ptr_, options_);
     if (options.mode != Options::MODE::READ_ONLY) {
-        ENGINE_LOG_INFO << "StartTimerTasks";
+        ENGINE_LOG_TRACE << "StartTimerTasks";
         StartTimerTasks();
     }
@@ -102,6 +102,7 @@ Status DBImpl::CreateTable(meta::TableSchema& table_schema) {
 Status DBImpl::DeleteTable(const std::string& table_id, const meta::DatesT& dates) {
     //dates partly delete files of the table but currently we don't support
+    ENGINE_LOG_DEBUG << "Prepare to delete table " << table_id;

     mem_mgr_->EraseMemVector(table_id); //not allow insert
     meta_ptr_->DeleteTable(table_id); //soft delete table
@@ -132,6 +133,7 @@ Status DBImpl::GetTableRowCount(const std::string& table_id, uint64_t& row_count
 Status DBImpl::InsertVectors(const std::string& table_id_,
         uint64_t n, const float* vectors, IDNumbers& vector_ids_) {
+    ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";

     auto start_time = METRICS_NOW_TIME;
     Status status = mem_mgr_->InsertVectors(table_id_, n, vectors, vector_ids_);
@@ -140,6 +142,8 @@ Status DBImpl::InsertVectors(const std::string& table_id_,
 //    std::chrono::microseconds time_span = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
 //    double average_time = double(time_span.count()) / n;

+    ENGINE_LOG_DEBUG << "Insert vectors to cache finished";
+
     CollectInsertMetrics(total_time, n, status.ok());
     return status;
@@ -160,6 +164,8 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
 Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
         const float* vectors, const meta::DatesT& dates, QueryResults& results) {
+    ENGINE_LOG_DEBUG << "Query by vectors";
+
     //get all table files from table
     meta::DatePartionedTableFilesSchema files;
     auto status = meta_ptr_->FilesToSearch(table_id, dates, files);
@@ -181,6 +187,8 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
 Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids,
         uint64_t k, uint64_t nq, const float* vectors,
         const meta::DatesT& dates, QueryResults& results) {
+    ENGINE_LOG_DEBUG << "Query by file ids";
+
     //get specified files
     std::vector<size_t> ids;
     for (auto &id : file_ids) {
@@ -269,6 +277,8 @@ void DBImpl::BackgroundTimerTask() {
             for(auto& iter : index_thread_results_) {
                 iter.wait();
             }
+
+            ENGINE_LOG_DEBUG << "DB background thread exit";
             break;
         }
@@ -287,6 +297,8 @@ void DBImpl::StartMetricTask() {
         return;
     }

+    ENGINE_LOG_TRACE << "Start metric task";
+
     server::Metrics::GetInstance().KeepingAliveCounterIncrement(METRIC_ACTION_INTERVAL);
     int64_t cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
     int64_t cache_total = cache::CpuCacheMgr::GetInstance()->CacheCapacity();
@@ -299,17 +311,14 @@ void DBImpl::StartMetricTask() {
     server::Metrics::GetInstance().GPUPercentGaugeSet();
     server::Metrics::GetInstance().GPUMemoryUsageGaugeSet();
     server::Metrics::GetInstance().OctetsSet();
+
+    ENGINE_LOG_TRACE << "Metric task finished";
 }

 void DBImpl::StartCompactionTask() {
-//    static int count = 0;
-//    count++;
-//    std::cout << "StartCompactionTask: " << count << std::endl;
-//    std::cout << "c: " << count++ << std::endl;
     static uint64_t compact_clock_tick = 0;
     compact_clock_tick++;
     if(compact_clock_tick%COMPACT_ACTION_INTERVAL != 0) {
-//        std::cout << "c r: " << count++ << std::endl;
         return;
     }
@@ -320,6 +329,10 @@ void DBImpl::StartCompactionTask() {
         compact_table_ids_.insert(id);
     }

+    if(!temp_table_ids.empty()) {
+        SERVER_LOG_DEBUG << "Insert cache serialized";
+    }
+
     //compactiong has been finished?
     if(!compact_thread_results_.empty()) {
         std::chrono::milliseconds span(10);
@@ -338,13 +351,15 @@ void DBImpl::StartCompactionTask() {
 Status DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date,
         const meta::TableFilesSchema& files) {
+    ENGINE_LOG_DEBUG << "Merge files for table" << table_id;
+
     meta::TableFileSchema table_file;
     table_file.table_id_ = table_id;
     table_file.date_ = date;

     Status status = meta_ptr_->CreateTableFile(table_file);
     if (!status.ok()) {
-        ENGINE_LOG_INFO << status.ToString() << std::endl;
+        ENGINE_LOG_ERROR << "Failed to create table: " << status.ToString();
         return status;
     }
@@ -396,6 +411,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
     meta::DatePartionedTableFilesSchema raw_files;
     auto status = meta_ptr_->FilesToMerge(table_id, raw_files);
     if (!status.ok()) {
+        ENGINE_LOG_ERROR << "Failed to get merge files for table: " << table_id;
         return status;
     }
@@ -417,12 +433,14 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
 }

 void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
+    ENGINE_LOG_TRACE << " Background compaction thread start";
+
     Status status;
     for (auto& table_id : table_ids) {
         status = BackgroundMergeFiles(table_id);
         if (!status.ok()) {
             ENGINE_LOG_ERROR << "Merge files for table " << table_id << " failed: " << status.ToString();
-            return;
+            continue;//let other table get chance to merge
         }
     }
@@ -433,6 +451,8 @@ void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
         ttl = meta::D_SEC;
     }
     meta_ptr_->CleanUpFilesWithTTL(ttl);
+
+    ENGINE_LOG_TRACE << " Background compaction thread exit";
 }

 void DBImpl::StartBuildIndexTask(bool force) {
@@ -477,6 +497,7 @@ Status DBImpl::BuildIndex(const std::string& table_id) {
 Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
     ExecutionEnginePtr to_index = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_);
     if(to_index == nullptr) {
+        ENGINE_LOG_ERROR << "Invalid engine type";
         return Status::Error("Invalid engine type");
     }
@@ -491,6 +512,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
     table_file.file_type_ = meta::TableFileSchema::INDEX; //for multi-db-path, distribute index file averagely to each path
     Status status = meta_ptr_->CreateTableFile(table_file);
     if (!status.ok()) {
+        ENGINE_LOG_ERROR << "Failed to create table: " << status.ToString();
         return status;
     }
@@ -559,6 +581,8 @@ Status DBImpl::BuildIndexByTable(const std::string& table_id) {
 }

 void DBImpl::BackgroundBuildIndex() {
+    ENGINE_LOG_TRACE << " Background build index thread start";
+
     std::unique_lock<std::mutex> lock(build_index_mutex_);
     meta::TableFilesSchema to_index_files;
     meta_ptr_->FilesToIndex(to_index_files);
@@ -574,6 +598,8 @@ void DBImpl::BackgroundBuildIndex() {
             break;
         }
     }
+
+    ENGINE_LOG_TRACE << " Background build index thread exit";
 }

 Status DBImpl::DropAll() {

View File

@@ -8,67 +8,88 @@
 #include "Meta.h"
 #include "Options.h"

 namespace zilliz {
 namespace milvus {
 namespace engine {
 namespace meta {

-auto StoragePrototype(const std::string& path);
+auto StoragePrototype(const std::string &path);

 class DBMetaImpl : public Meta {
 public:
-    DBMetaImpl(const DBMetaOptions& options_);
-
-    virtual Status CreateTable(TableSchema& table_schema) override;
-    virtual Status DescribeTable(TableSchema& group_info_) override;
-    virtual Status HasTable(const std::string& table_id, bool& has_or_not) override;
-    virtual Status AllTables(std::vector<TableSchema>& table_schema_array) override;
-
-    virtual Status DeleteTable(const std::string& table_id) override;
-    virtual Status DeleteTableFiles(const std::string& table_id) override;
-
-    virtual Status CreateTableFile(TableFileSchema& file_schema) override;
-    virtual Status DropPartitionsByDates(const std::string& table_id,
-                                         const DatesT& dates) override;
-
-    virtual Status GetTableFiles(const std::string& table_id,
-                                 const std::vector<size_t>& ids,
-                                 TableFilesSchema& table_files) override;
-
-    virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) override;
-
-    virtual Status UpdateTableFilesToIndex(const std::string& table_id) override;
-
-    virtual Status UpdateTableFile(TableFileSchema& file_schema) override;
-
-    virtual Status UpdateTableFiles(TableFilesSchema& files) override;
-
-    virtual Status FilesToSearch(const std::string& table_id,
-                                 const DatesT& partition,
-                                 DatePartionedTableFilesSchema& files) override;
-
-    virtual Status FilesToMerge(const std::string& table_id,
-                                DatePartionedTableFilesSchema& files) override;
-
-    virtual Status FilesToIndex(TableFilesSchema&) override;
-
-    virtual Status Archive() override;
-
-    virtual Status Size(uint64_t& result) override;
-
-    virtual Status CleanUp() override;
-
-    virtual Status CleanUpFilesWithTTL(uint16_t seconds) override;
-
-    virtual Status DropAll() override;
-
-    virtual Status Count(const std::string& table_id, uint64_t& result) override;
-
-    virtual ~DBMetaImpl();
-
-private:
-    Status NextFileId(std::string& file_id);
-    Status NextTableId(std::string& table_id);
+    explicit DBMetaImpl(const DBMetaOptions &options_);
+
+    Status
+    CreateTable(TableSchema &table_schema) override;
+
+    Status
+    DescribeTable(TableSchema &group_info_) override;
+
+    Status
+    HasTable(const std::string &table_id, bool &has_or_not) override;
+
+    Status
+    AllTables(std::vector<TableSchema> &table_schema_array) override;
+
+    Status
+    DeleteTable(const std::string &table_id) override;
+
+    Status
+    DeleteTableFiles(const std::string &table_id) override;
+
+    Status
+    CreateTableFile(TableFileSchema &file_schema) override;
+
+    Status
+    DropPartitionsByDates(const std::string &table_id, const DatesT &dates) override;
+
+    Status
+    GetTableFiles(const std::string &table_id, const std::vector<size_t> &ids, TableFilesSchema &table_files) override;
+
+    Status
+    HasNonIndexFiles(const std::string &table_id, bool &has) override;
+
+    Status
+    UpdateTableFilesToIndex(const std::string &table_id) override;
+
+    Status
+    UpdateTableFile(TableFileSchema &file_schema) override;
+
+    Status
+    UpdateTableFiles(TableFilesSchema &files) override;
+
+    Status
+    FilesToSearch(const std::string &table_id, const DatesT &partition, DatePartionedTableFilesSchema &files) override;
+
+    Status
+    FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) override;
+
+    Status
+    FilesToIndex(TableFilesSchema &) override;
+
+    Status
+    Archive() override;
+
+    Status
+    Size(uint64_t &result) override;
+
+    Status
+    CleanUp() override;
+
+    Status
+    CleanUpFilesWithTTL(uint16_t seconds) override;
+
+    Status
+    DropAll() override;
+
+    Status Count(const std::string &table_id, uint64_t &result) override;
+
+    ~DBMetaImpl() override;
+
+private:
+    Status NextFileId(std::string &file_id);
+    Status NextTableId(std::string &table_id);
     Status DiscardFiles(long to_discard_size);
     Status Initialize();

View File

@@ -13,7 +13,9 @@ namespace zilliz {
 namespace milvus {
 namespace engine {

-IDGenerator::~IDGenerator() {}
+IDGenerator::~IDGenerator() = default;
+
+constexpr size_t SimpleIDGenerator::MAX_IDS_PER_MICRO;

 IDNumber SimpleIDGenerator::GetNextIDNumber() {
     auto now = std::chrono::system_clock::now();
View File

@@ -10,28 +10,39 @@
 #include <cstddef>
 #include <vector>

 namespace zilliz {
 namespace milvus {
 namespace engine {

 class IDGenerator {
 public:
-    virtual IDNumber GetNextIDNumber() = 0;
-    virtual void GetNextIDNumbers(size_t n, IDNumbers& ids) = 0;
-    virtual ~IDGenerator();
+    virtual
+    IDNumber GetNextIDNumber() = 0;
+
+    virtual void
+    GetNextIDNumbers(size_t n, IDNumbers &ids) = 0;
+
+    virtual
+    ~IDGenerator() = 0;
 }; // IDGenerator


 class SimpleIDGenerator : public IDGenerator {
 public:
-    virtual IDNumber GetNextIDNumber() override;
-    virtual void GetNextIDNumbers(size_t n, IDNumbers& ids) override;
+    ~SimpleIDGenerator() override = default;
+
+    IDNumber
+    GetNextIDNumber() override;
+
+    void
+    GetNextIDNumbers(size_t n, IDNumbers &ids) override;

 private:
-    void NextIDNumbers(size_t n, IDNumbers& ids);
-    const size_t MAX_IDS_PER_MICRO = 1000;
+    void
+    NextIDNumbers(size_t n, IDNumbers &ids);
+
+    static constexpr size_t MAX_IDS_PER_MICRO = 1000;
 }; // SimpleIDGenerator
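
This hunk explains the companion line in IDGenerator.cpp ("constexpr size_t SimpleIDGenerator::MAX_IDS_PER_MICRO;"): before C++17, a static constexpr data member that is ODR-used, for example bound to a const reference, needs one out-of-line definition in a translation unit. A minimal self-contained sketch of that rule:

#include <algorithm>
#include <cstddef>
#include <iostream>

struct Gen {
    static constexpr std::size_t MAX_IDS_PER_MICRO = 1000;  // declaration + initializer
};

constexpr std::size_t Gen::MAX_IDS_PER_MICRO;  // definition; redundant from C++17 on

int main() {
    std::size_t n = 1500;
    // std::min takes its arguments by const reference, which ODR-uses the member,
    // so the out-of-line definition above is what keeps the link step happy in C++11/14.
    std::cout << std::min(n, Gen::MAX_IDS_PER_MICRO) << '\n';  // prints 1000
}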

View File

@@ -13,6 +13,8 @@ namespace milvus {
 namespace engine {
 namespace meta {

+Meta::~Meta() = default;
+
 DateT Meta::GetDate(const std::time_t& t, int day_delta) {
     struct tm ltm;
     localtime_r(&t, &ltm);

View File

@@ -20,56 +20,86 @@ namespace meta {
 class Meta {
 public:
     using Ptr = std::shared_ptr<Meta>;

-    virtual Status CreateTable(TableSchema& table_schema) = 0;
-    virtual Status DescribeTable(TableSchema& table_schema) = 0;
-    virtual Status HasTable(const std::string& table_id, bool& has_or_not) = 0;
-    virtual Status AllTables(std::vector<TableSchema>& table_schema_array) = 0;
-
-    virtual Status DeleteTable(const std::string& table_id) = 0;
-    virtual Status DeleteTableFiles(const std::string& table_id) = 0;
-
-    virtual Status CreateTableFile(TableFileSchema& file_schema) = 0;
-    virtual Status DropPartitionsByDates(const std::string& table_id,
-                                         const DatesT& dates) = 0;
-
-    virtual Status GetTableFiles(const std::string& table_id,
-                                 const std::vector<size_t>& ids,
-                                 TableFilesSchema& table_files) = 0;
-
-    virtual Status UpdateTableFilesToIndex(const std::string& table_id) = 0;
-
-    virtual Status UpdateTableFile(TableFileSchema& file_schema) = 0;
-
-    virtual Status UpdateTableFiles(TableFilesSchema& files) = 0;
-
-    virtual Status FilesToSearch(const std::string &table_id,
-                                 const DatesT &partition,
-                                 DatePartionedTableFilesSchema& files) = 0;
-
-    virtual Status FilesToMerge(const std::string& table_id,
-                                DatePartionedTableFilesSchema& files) = 0;
-
-    virtual Status Size(uint64_t& result) = 0;
-
-    virtual Status Archive() = 0;
-
-    virtual Status FilesToIndex(TableFilesSchema&) = 0;
-
-    virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) = 0;
-
-    virtual Status CleanUp() = 0;
-    virtual Status CleanUpFilesWithTTL(uint16_t) = 0;
-
-    virtual Status DropAll() = 0;
-
-    virtual Status Count(const std::string& table_id, uint64_t& result) = 0;
-
-    static DateT GetDate(const std::time_t& t, int day_delta = 0);
-    static DateT GetDate();
-    static DateT GetDateWithDelta(int day_delta);
+    virtual
+    ~Meta() = 0;
+
+    virtual Status
+    CreateTable(TableSchema &table_schema) = 0;
+
+    virtual Status
+    DescribeTable(TableSchema &table_schema) = 0;
+
+    virtual Status
+    HasTable(const std::string &table_id, bool &has_or_not) = 0;
+
+    virtual Status
+    AllTables(std::vector<TableSchema> &table_schema_array) = 0;
+
+    virtual Status
+    DeleteTable(const std::string &table_id) = 0;
+
+    virtual Status
+    DeleteTableFiles(const std::string &table_id) = 0;
+
+    virtual Status
+    CreateTableFile(TableFileSchema &file_schema) = 0;
+
+    virtual Status
+    DropPartitionsByDates(const std::string &table_id, const DatesT &dates) = 0;
+
+    virtual Status
+    GetTableFiles(const std::string &table_id, const std::vector<size_t> &ids, TableFilesSchema &table_files) = 0;
+
+    virtual Status
+    UpdateTableFilesToIndex(const std::string &table_id) = 0;
+
+    virtual Status
+    UpdateTableFile(TableFileSchema &file_schema) = 0;
+
+    virtual Status
+    UpdateTableFiles(TableFilesSchema &files) = 0;
+
+    virtual Status
+    FilesToSearch(const std::string &table_id, const DatesT &partition, DatePartionedTableFilesSchema &files) = 0;
+
+    virtual Status
+    FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) = 0;
+
+    virtual Status
+    Size(uint64_t &result) = 0;
+
+    virtual Status
+    Archive() = 0;
+
+    virtual Status
+    FilesToIndex(TableFilesSchema &) = 0;
+
+    virtual Status
+    HasNonIndexFiles(const std::string &table_id, bool &has) = 0;
+
+    virtual Status
+    CleanUp() = 0;
+
+    virtual Status
+    CleanUpFilesWithTTL(uint16_t) = 0;
+
+    virtual Status
+    DropAll() = 0;
+
+    virtual Status
+    Count(const std::string &table_id, uint64_t &result) = 0;
+
+    static DateT
+    GetDate(const std::time_t &t, int day_delta = 0);
+
+    static DateT
+    GetDate();
+
+    static DateT
+    GetDateWithDelta(int day_delta);
 }; // MetaData

View File

@@ -12,79 +12,80 @@
 #include "mysql++/mysql++.h"

 #include <mutex>

 namespace zilliz {
 namespace milvus {
 namespace engine {
 namespace meta {

 //    auto StoragePrototype(const std::string& path);
 using namespace mysqlpp;

 class MySQLMetaImpl : public Meta {
 public:
-    MySQLMetaImpl(const DBMetaOptions& options_, const int& mode);
+    MySQLMetaImpl(const DBMetaOptions &options_, const int &mode);

-    virtual Status CreateTable(TableSchema& table_schema) override;
-    virtual Status DescribeTable(TableSchema& group_info_) override;
-    virtual Status HasTable(const std::string& table_id, bool& has_or_not) override;
-    virtual Status AllTables(std::vector<TableSchema>& table_schema_array) override;
+    Status CreateTable(TableSchema &table_schema) override;
+    Status DescribeTable(TableSchema &group_info_) override;
+    Status HasTable(const std::string &table_id, bool &has_or_not) override;
+    Status AllTables(std::vector<TableSchema> &table_schema_array) override;

-    virtual Status DeleteTable(const std::string& table_id) override;
-    virtual Status DeleteTableFiles(const std::string& table_id) override;
+    Status DeleteTable(const std::string &table_id) override;
+    Status DeleteTableFiles(const std::string &table_id) override;

-    virtual Status CreateTableFile(TableFileSchema& file_schema) override;
-    virtual Status DropPartitionsByDates(const std::string& table_id,
-                                         const DatesT& dates) override;
+    Status CreateTableFile(TableFileSchema &file_schema) override;
+    Status DropPartitionsByDates(const std::string &table_id,
+                                 const DatesT &dates) override;

-    virtual Status GetTableFiles(const std::string& table_id,
-                                 const std::vector<size_t>& ids,
-                                 TableFilesSchema& table_files) override;
+    Status GetTableFiles(const std::string &table_id,
+                         const std::vector<size_t> &ids,
+                         TableFilesSchema &table_files) override;

-    virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) override;
+    Status HasNonIndexFiles(const std::string &table_id, bool &has) override;

-    virtual Status UpdateTableFile(TableFileSchema& file_schema) override;
+    Status UpdateTableFile(TableFileSchema &file_schema) override;

-    virtual Status UpdateTableFilesToIndex(const std::string& table_id) override;
+    Status UpdateTableFilesToIndex(const std::string &table_id) override;

-    virtual Status UpdateTableFiles(TableFilesSchema& files) override;
+    Status UpdateTableFiles(TableFilesSchema &files) override;

-    virtual Status FilesToSearch(const std::string& table_id,
-                                 const DatesT& partition,
-                                 DatePartionedTableFilesSchema& files) override;
+    Status FilesToSearch(const std::string &table_id,
+                         const DatesT &partition,
+                         DatePartionedTableFilesSchema &files) override;

-    virtual Status FilesToMerge(const std::string& table_id,
-                                DatePartionedTableFilesSchema& files) override;
+    Status FilesToMerge(const std::string &table_id,
+                        DatePartionedTableFilesSchema &files) override;

-    virtual Status FilesToIndex(TableFilesSchema&) override;
+    Status FilesToIndex(TableFilesSchema &) override;

-    virtual Status Archive() override;
+    Status Archive() override;

-    virtual Status Size(uint64_t& result) override;
+    Status Size(uint64_t &result) override;

-    virtual Status CleanUp() override;
+    Status CleanUp() override;

-    virtual Status CleanUpFilesWithTTL(uint16_t seconds) override;
+    Status CleanUpFilesWithTTL(uint16_t seconds) override;

-    virtual Status DropAll() override;
+    Status DropAll() override;

-    virtual Status Count(const std::string& table_id, uint64_t& result) override;
+    Status Count(const std::string &table_id, uint64_t &result) override;

     virtual ~MySQLMetaImpl();

 private:
-    Status NextFileId(std::string& file_id);
-    Status NextTableId(std::string& table_id);
+    Status NextFileId(std::string &file_id);
+    Status NextTableId(std::string &table_id);
     Status DiscardFiles(long long to_discard_size);
     Status Initialize();

     const DBMetaOptions options_;
     const int mode_;

     std::shared_ptr<MySQLConnectionPool> mysql_connection_pool_;
     bool safe_grab = false;

 //    std::mutex connectionMutex_;
 }; // DBMetaImpl

 } // namespace meta
 } // namespace engine

View File

@@ -20,6 +20,7 @@ class ReuseCacheIndexStrategy {
 public:
     bool Schedule(const SearchContextPtr &context, std::list<ScheduleTaskPtr>& task_list) {
         if(context == nullptr) {
+            ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
             return false;
         }
@@ -64,6 +65,7 @@ class DeleteTableStrategy {
 public:
     bool Schedule(const DeleteContextPtr &context, std::list<ScheduleTaskPtr> &task_list) {
         if (context == nullptr) {
+            ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
             return false;
         }
@@ -103,6 +105,7 @@ public:
 bool TaskDispatchStrategy::Schedule(const ScheduleContextPtr &context_ptr,
                                     std::list<zilliz::milvus::engine::ScheduleTaskPtr> &task_list) {
     if(context_ptr == nullptr) {
+        ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
         return false;
     }

View File

@@ -31,6 +31,7 @@ TaskScheduler& TaskScheduler::GetInstance() {
 bool
 TaskScheduler::Start() {
     if(!stopped_) {
+        SERVER_LOG_INFO << "Task Scheduler isn't started";
         return true;
     }
@@ -47,6 +48,7 @@ TaskScheduler::Start() {
 bool
 TaskScheduler::Stop() {
     if(stopped_) {
+        SERVER_LOG_INFO << "Task Scheduler already stopped";
         return true;
     }
@@ -80,7 +82,7 @@ TaskScheduler::TaskDispatchWorker() {
         ScheduleTaskPtr task_ptr = task_dispatch_queue_.Take();
         if(task_ptr == nullptr) {
             SERVER_LOG_INFO << "Stop db task dispatch thread";
-            break;//exit
+            return true;
         }

         //execute task
@@ -98,8 +100,8 @@ TaskScheduler::TaskWorker() {
     while(true) {
         ScheduleTaskPtr task_ptr = task_queue_.Take();
         if(task_ptr == nullptr) {
-            SERVER_LOG_INFO << "Stop db task thread";
-            break;//exit
+            SERVER_LOG_INFO << "Stop db task worker thread";
+            return true;
         }

         //execute task
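
The nullptr task here acts as a shutdown sentinel: each worker loops on Take() and returns when it receives a null pointer. A self-contained sketch of that pattern, with a simplified queue whose names are illustrative rather than the Milvus BlockingQueue API:

#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

// Minimal blocking queue: Take() waits until an item is available.
template <typename T>
class BlockingQueue {
 public:
    void Put(T item) {
        { std::lock_guard<std::mutex> lk(mu_); q_.push(std::move(item)); }
        cv_.notify_one();
    }
    T Take() {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return !q_.empty(); });
        T item = std::move(q_.front());
        q_.pop();
        return item;
    }
 private:
    std::mutex mu_;
    std::condition_variable cv_;
    std::queue<T> q_;
};

int main() {
    BlockingQueue<std::shared_ptr<int>> tasks;
    std::thread worker([&] {
        while (true) {
            auto task = tasks.Take();
            if (task == nullptr) {        // shutdown sentinel, as in TaskWorker()
                std::cout << "worker exit\n";
                return;
            }
            std::cout << "run task " << *task << '\n';
        }
    });
    tasks.Put(std::make_shared<int>(1));
    tasks.Put(nullptr);                   // signal stop
    worker.join();
}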

View File

@@ -5,14 +5,60 @@
  ******************************************************************************/
 #include "SearchTask.h"
 #include "metrics/Metrics.h"
-#include "utils/Log.h"
+#include "db/Log.h"
 #include "utils/TimeRecorder.h"

+#include <thread>
+
 namespace zilliz {
 namespace milvus {
 namespace engine {

 namespace {
+static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
+static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
+
+bool NeedParallelReduce(uint64_t nq, uint64_t topk) {
+    server::ServerConfig &config = server::ServerConfig::GetInstance();
+    server::ConfigNode& db_config = config.GetConfig(server::CONFIG_DB);
+    bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, true);
+    if(!need_parallel) {
+        return false;
+    }
+
+    return nq*topk >= PARALLEL_REDUCE_THRESHOLD;
+}
+
+void ParallelReduce(std::function<void(size_t, size_t)>& reduce_function, size_t max_index) {
+    size_t reduce_batch = PARALLEL_REDUCE_BATCH;
+
+    auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
+    if(thread_count > 0) {
+        reduce_batch = max_index/thread_count + 1;
+    }
+    ENGINE_LOG_DEBUG << "use " << thread_count <<
+                     " thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
+
+    std::vector<std::shared_ptr<std::thread> > thread_array;
+    size_t from_index = 0;
+    while(from_index < max_index) {
+        size_t to_index = from_index + reduce_batch;
+        if(to_index > max_index) {
+            to_index = max_index;
+        }
+
+        auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
+        thread_array.push_back(reduce_thread);
+
+        from_index = to_index;
+    }
+
+    for(auto& thread_ptr : thread_array) {
+        thread_ptr->join();
+    }
+}
+
 void CollectDurationMetrics(int index_type, double total_time) {
     switch(index_type) {
         case meta::TableFileSchema::RAW: {
@@ -32,7 +78,7 @@ void CollectDurationMetrics(int index_type, double total_time) {
 std::string GetMetricType() {
     server::ServerConfig &config = server::ServerConfig::GetInstance();
-    server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
+    server::ConfigNode& engine_config = config.GetConfig(server::CONFIG_ENGINE);
     return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
 }
@@ -51,7 +97,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
         return nullptr;
     }

-    SERVER_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
+    ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
                      << search_contexts_.size() << " tasks";

     server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
@@ -79,6 +125,9 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
             auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
             SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);

+            span = rc.RecordSection("cluster result for context:" + context->Identity());
+            context->AccumReduceCost(span);
+
             //step 4: pick up topk result
             SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
@@ -86,7 +135,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
             context->AccumReduceCost(span);

         } catch (std::exception& ex) {
-            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
+            ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
             context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
             continue;
         }
@@ -112,23 +161,32 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
     if(output_ids.size() < nq*topk || output_distence.size() < nq*topk) {
         std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
                           " distance array size: " + std::to_string(output_distence.size());
-        SERVER_LOG_ERROR << msg;
+        ENGINE_LOG_ERROR << msg;
         return Status::Error(msg);
     }

     result_set.clear();
-    result_set.reserve(nq);
-    for (auto i = 0; i < nq; i++) {
-        SearchContext::Id2DistanceMap id_distance;
-        id_distance.reserve(topk);
-        for (auto k = 0; k < topk; k++) {
-            uint64_t index = i * topk + k;
-            if(output_ids[index] < 0) {
-                continue;
-            }
-            id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
-        }
-        result_set.emplace_back(id_distance);
+    result_set.resize(nq);
+
+    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
+        for (auto i = from_index; i < to_index; i++) {
+            SearchContext::Id2DistanceMap id_distance;
+            id_distance.reserve(topk);
+            for (auto k = 0; k < topk; k++) {
+                uint64_t index = i * topk + k;
+                if(output_ids[index] < 0) {
+                    continue;
+                }
+                id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
+            }
+            result_set[i] = id_distance;
+        }
+    };
+
+    if(NeedParallelReduce(nq, topk)) {
+        ParallelReduce(reduce_worker, nq);
+    } else {
+        reduce_worker(0, nq);
     }

     return Status::OK();
@@ -140,7 +198,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                                bool ascending) {
     //Note: the score_src and score_target are already arranged by score in ascending order
     if(distance_src.empty()) {
-        SERVER_LOG_WARNING << "Empty distance source array";
+        ENGINE_LOG_WARNING << "Empty distance source array";
         return Status::OK();
     }
@@ -218,14 +276,22 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
     if (result_src.size() != result_target.size()) {
         std::string msg = "Invalid result set size";
-        SERVER_LOG_ERROR << msg;
+        ENGINE_LOG_ERROR << msg;
         return Status::Error(msg);
     }

-    for (size_t i = 0; i < result_src.size(); i++) {
-        SearchContext::Id2DistanceMap &score_src = result_src[i];
-        SearchContext::Id2DistanceMap &score_target = result_target[i];
-        SearchTask::MergeResult(score_src, score_target, topk, ascending);
+    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
+        for (size_t i = from_index; i < to_index; i++) {
+            SearchContext::Id2DistanceMap &score_src = result_src[i];
+            SearchContext::Id2DistanceMap &score_target = result_target[i];
+            SearchTask::MergeResult(score_src, score_target, topk, ascending);
+        }
+    };
+
+    if(NeedParallelReduce(result_src.size(), topk)) {
+        ParallelReduce(ReduceWorker, result_src.size());
+    } else {
+        ReduceWorker(0, result_src.size());
     }

     return Status::OK();
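
To see the reduce-partitioning scheme above in isolation: split [0, max_index) into roughly per-core batches, run the worker lambda on each batch in its own thread, and join. A hedged self-contained sketch, simplified relative to the ParallelReduce above (no config check, plain std::thread values instead of shared_ptr):

#include <algorithm>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

// Simplified stand-in for ParallelReduce: run reduce_function over
// [0, max_index) in roughly equal batches, one batch per worker thread.
void ParallelReduceSketch(const std::function<void(size_t, size_t)>& reduce_function,
                          size_t max_index) {
    size_t thread_count = std::thread::hardware_concurrency();
    size_t batch = (thread_count > 1) ? max_index / (thread_count - 1) + 1 : max_index;

    std::vector<std::thread> workers;
    for (size_t from = 0; from < max_index; from += batch) {
        size_t to = std::min(from + batch, max_index);
        workers.emplace_back(reduce_function, from, to);  // disjoint ranges: no data race
    }
    for (auto& w : workers) {
        w.join();
    }
}

int main() {
    std::vector<long long> squares(100000);
    ParallelReduceSketch([&](size_t from, size_t to) {
        for (size_t i = from; i < to; ++i) squares[i] = static_cast<long long>(i) * i;
    }, squares.size());
    std::cout << squares[99999] << '\n';  // 9999800001
}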

View File

@@ -233,21 +233,22 @@ ClientTest::Test(const std::string& address, const std::string& port) {
         PrintTableSchema(tb_schema);
     }

+    //add vectors
     std::vector<std::pair<int64_t, RowRecord>> search_record_array;
-    {//add vectors
-        for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
-            TimeRecorder recorder("Add vector No." + std::to_string(i));
-            std::vector<RowRecord> record_array;
-            int64_t begin_index = i * BATCH_ROW_COUNT;
-            BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
-            std::vector<int64_t> record_ids;
-            Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
-            std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
-            std::cout << "Returned id array count: " << record_ids.size() << std::endl;
+    for (int i = 0; i < ADD_VECTOR_LOOP; i++) {
+        TimeRecorder recorder("Add vector No." + std::to_string(i));
+        std::vector<RowRecord> record_array;
+        int64_t begin_index = i * BATCH_ROW_COUNT;
+        BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
+        std::vector<int64_t> record_ids;
+        Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
+        std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
+        std::cout << "Returned id array count: " << record_ids.size() << std::endl;

-            if(search_record_array.size() < NQ) {
+        if(i == 0) {
+            for(int64_t k = SEARCH_TARGET; k < SEARCH_TARGET + NQ; k++) {
                 search_record_array.push_back(
-                        std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
+                        std::make_pair(record_ids[k], record_array[k]));
             }
         }
     }

View File

@@ -191,6 +191,7 @@ ServerError CreateTableTask::OnExecute() {
         }

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "CreateTableTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -236,6 +237,7 @@ ServerError DescribeTableTask::OnExecute() {
         schema_.store_raw_vector = table_info.store_raw_data_;

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "DescribeTableTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -279,6 +281,7 @@ ServerError BuildIndexTask::OnExecute() {
         rc.ElapseFromBegin("totally cost");

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "BuildIndexTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -316,6 +319,7 @@ ServerError HasTableTask::OnExecute() {
         rc.ElapseFromBegin("totally cost");

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "HasTableTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -365,6 +369,7 @@ ServerError DeleteTableTask::OnExecute() {
         rc.ElapseFromBegin("totally cost");

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "DeleteTableTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -481,6 +486,7 @@ ServerError AddVectorTask::OnExecute() {
         rc.ElapseFromBegin("totally cost");

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "AddVectorTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -604,6 +610,7 @@ ServerError SearchVectorTaskBase::OnExecute() {
                          << " construct result(" << (span_result/total_cost)*100.0 << "%)";

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "SearchVectorTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }
@@ -739,6 +746,7 @@ ServerError GetTableRowCountTask::OnExecute() {
         rc.ElapseFromBegin("totally cost");

     } catch (std::exception& ex) {
+        SERVER_LOG_ERROR << "GetTableRowCountTask encounter exception: " << ex.what();
         return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
     }

View File

@@ -29,6 +29,7 @@ static const std::string CONFIG_DB_INDEX_TRIGGER_SIZE = "index_building_threshold";
 static const std::string CONFIG_DB_ARCHIVE_DISK = "archive_disk_threshold";
 static const std::string CONFIG_DB_ARCHIVE_DAYS = "archive_days_threshold";
 static const std::string CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size";
+static const std::string CONFIG_DB_PARALLEL_REDUCE = "parallel_reduce";

 static const std::string CONFIG_LOG = "log_config";

View File

@ -6,6 +6,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "db/scheduler/task/SearchTask.h" #include "db/scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
#include <cmath> #include <cmath>
#include <vector> #include <vector>
@ -17,27 +19,33 @@ static constexpr uint64_t NQ = 15;
static constexpr uint64_t TOP_K = 64; static constexpr uint64_t TOP_K = 64;
void BuildResult(uint64_t nq, void BuildResult(uint64_t nq,
uint64_t top_k, uint64_t topk,
bool ascending,
std::vector<long> &output_ids, std::vector<long> &output_ids,
std::vector<float> &output_distence) { std::vector<float> &output_distence) {
output_ids.clear(); output_ids.clear();
output_ids.resize(nq*top_k); output_ids.resize(nq*topk);
output_distence.clear(); output_distence.clear();
output_distence.resize(nq*top_k); output_distence.resize(nq*topk);
for(uint64_t i = 0; i < nq; i++) { for(uint64_t i = 0; i < nq; i++) {
for(uint64_t j = 0; j < top_k; j++) { for(uint64_t j = 0; j < topk; j++) {
output_ids[i * top_k + j] = (long)(drand48()*100000); output_ids[i * topk + j] = (long)(drand48()*100000);
output_distence[i * top_k + j] = j + drand48(); output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
} }
} }
} }
void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1, void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
const engine::SearchContext::Id2DistanceMap& src_2, const engine::SearchContext::Id2DistanceMap& src_2,
const engine::SearchContext::Id2DistanceMap& target) { const engine::SearchContext::Id2DistanceMap& target,
bool ascending) {
for(uint64_t i = 0; i < target.size() - 1; i++) { for(uint64_t i = 0; i < target.size() - 1; i++) {
ASSERT_LE(target[i].second, target[i + 1].second); if(ascending) {
ASSERT_LE(target[i].second, target[i + 1].second);
} else {
ASSERT_GE(target[i].second, target[i + 1].second);
}
} }
using ID2DistMap = std::map<long, float>; using ID2DistMap = std::map<long, float>;
@ -57,9 +65,52 @@ void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
} }
} }
void CheckCluster(const std::vector<long>& target_ids,
const std::vector<float>& target_distence,
const engine::SearchContext::ResultSet& src_result,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
for(int64_t i = 0; i < nq; i++) {
auto& res = src_result[i];
ASSERT_EQ(res.size(), topk);
if(res.empty()) {
continue;
}
ASSERT_EQ(res[0].first, target_ids[i*topk]);
ASSERT_EQ(res[topk - 1].first, target_ids[i*topk + topk - 1]);
}
}
void CheckTopkResult(const engine::SearchContext::ResultSet& src_result,
bool ascending,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
for(int64_t i = 0; i < nq; i++) {
auto& res = src_result[i];
ASSERT_EQ(res.size(), topk);
if(res.empty()) {
continue;
}
for(int64_t k = 0; k < topk - 1; k++) {
if(ascending) {
ASSERT_LE(res[k].second, res[k + 1].second);
} else {
ASSERT_GE(res[k].second, res[k + 1].second);
}
}
}
}
} }
TEST(DBSearchTest, TOPK_TEST) { TEST(DBSearchTest, TOPK_TEST) {
bool ascending = true;
std::vector<long> target_ids; std::vector<long> target_ids;
std::vector<float> target_distence; std::vector<float> target_distence;
engine::SearchContext::ResultSet src_result; engine::SearchContext::ResultSet src_result;
@@ -67,19 +118,19 @@ TEST(DBSearchTest, TOPK_TEST) {
    ASSERT_FALSE(status.ok());
    ASSERT_TRUE(src_result.empty());

    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
    ASSERT_TRUE(status.ok());
    ASSERT_EQ(src_result.size(), NQ);

    engine::SearchContext::ResultSet target_result;
    status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());

    status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
    ASSERT_FALSE(status.ok());

    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    ASSERT_TRUE(src_result.empty());
    ASSERT_EQ(target_result.size(), NQ);
@@ -87,21 +138,21 @@ TEST(DBSearchTest, TOPK_TEST) {
    std::vector<long> src_ids;
    std::vector<float> src_distence;
    uint64_t wrong_topk = TOP_K - 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
    status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
    ASSERT_TRUE(status.ok());

    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }

    wrong_topk = TOP_K + 10;
    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);

    status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
    ASSERT_TRUE(status.ok());
    for(uint64_t i = 0; i < NQ; i++) {
        ASSERT_EQ(target_result[i].size(), TOP_K);
    }
@@ -109,6 +160,7 @@ TEST(DBSearchTest, TOPK_TEST) {
}
TEST(DBSearchTest, MERGE_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    std::vector<long> src_ids;
@@ -116,8 +168,8 @@ TEST(DBSearchTest, MERGE_TEST) {
    engine::SearchContext::ResultSet src_result, target_result;

    uint64_t src_count = 5, target_count = 8;
    BuildResult(1, src_count, ascending, src_ids, src_distence);
    BuildResult(1, target_count, ascending, target_ids, target_distence);
    auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
    ASSERT_TRUE(status.ok());
    status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
@@ -126,37 +178,107 @@ TEST(DBSearchTest, MERGE_TEST) {
    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target = target_result[0];
        status = engine::SearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), 10);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target;
        status = engine::SearchTask::MergeResult(src, target, 10, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count);
        ASSERT_TRUE(src.empty());
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    {
        engine::SearchContext::Id2DistanceMap src = src_result[0];
        engine::SearchContext::Id2DistanceMap target = target_result[0];
        status = engine::SearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }

    {
        engine::SearchContext::Id2DistanceMap target = src_result[0];
        engine::SearchContext::Id2DistanceMap src = target_result[0];
        status = engine::SearchTask::MergeResult(src, target, 30, ascending);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(target.size(), src_count + target_count);
        CheckResult(src_result[0], target_result[0], target, ascending);
    }
}
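// The four blocks above cover: a capped merge (10 < src_count + target_count,
// so only the 10 best pairs survive), merging into an empty target (src is
// drained into it), an uncapped merge (30 >= 13 keeps every pair), and the
// same uncapped merge with src and target swapped.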
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
    bool ascending = true;
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    auto DoCluster = [&](int64_t nq, int64_t topk) {
        server::TimeRecorder rc("DoCluster");
        src_result.clear();
        BuildResult(nq, topk, ascending, target_ids, target_distence);
        rc.RecordSection("build id/distance map");

        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(src_result.size(), nq);
        rc.RecordSection("cluster result");

        CheckCluster(target_ids, target_distence, src_result, nq, topk);
        rc.RecordSection("check result");
    };

    DoCluster(10000, 1000);
    DoCluster(333, 999);
    DoCluster(1, 1000);
    DoCluster(1, 1);
    DoCluster(7, 0);
    DoCluster(9999, 1);
    DoCluster(10001, 1);
    DoCluster(58273, 1234);
}
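// The nq values straddle 10000 (9999/10000/10001), which presumably brackets
// the internal threshold at which ClusterResult switches to parallel
// clustering; topk == 0 and nq == 1 exercise the degenerate shapes.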
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
    std::vector<long> target_ids;
    std::vector<float> target_distence;
    engine::SearchContext::ResultSet src_result;

    std::vector<long> insufficient_ids;
    std::vector<float> insufficient_distence;
    engine::SearchContext::ResultSet insufficient_result;

    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
        src_result.clear();
        insufficient_result.clear();

        server::TimeRecorder rc("DoTopk");
        BuildResult(nq, topk, ascending, target_ids, target_distence);
        auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
        rc.RecordSection("cluster result");

        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
        status = engine::SearchTask::ClusterResult(insufficient_ids, insufficient_distence, nq, insufficient_topk, insufficient_result);
        rc.RecordSection("cluster insufficient result");

        engine::SearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
        ASSERT_TRUE(status.ok());
        rc.RecordSection("topk");

        CheckTopkResult(src_result, ascending, nq, topk);
        rc.RecordSection("check result");
    };

    DoTopk(5, 10, 4, false);
    DoTopk(20005, 998, 123, true);
    DoTopk(9987, 12, 10, false);
    DoTopk(77777, 1000, 1, false);
    DoTopk(5432, 8899, 8899, true);
}
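// Usage sketch (illustrative, mirroring the calls exercised above): a caller
// clusters each batch of raw ids/distances into per-query lists, then folds
// successive batches into a running top-k:
//
//     engine::SearchContext::ResultSet batch, merged;
//     engine::SearchTask::ClusterResult(ids, distances, nq, topk, batch);
//     engine::SearchTask::TopkResult(batch, topk, /*ascending=*/true, merged);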