diff --git a/internal/core/src/storage/ThreadPool.cpp b/internal/core/src/storage/ThreadPool.cpp index d287dc1389..b81c99c2b7 100644 --- a/internal/core/src/storage/ThreadPool.cpp +++ b/internal/core/src/storage/ThreadPool.cpp @@ -20,21 +20,80 @@ namespace milvus { void ThreadPool::Init() { - for (int i = 0; i < threads_.size(); i++) { - threads_[i] = std::thread(Worker(this, i)); + std::lock_guard lock(mutex_); + for (int i = 0; i < min_threads_size_; i++) { + std::thread t(&ThreadPool::Worker, this); + assert(threads_.find(t.get_id()) == threads_.end()); + threads_[t.get_id()] = std::move(t); + current_threads_size_++; } } void ThreadPool::ShutDown() { LOG_SEGCORE_INFO_ << "Start shutting down " << name_; - shutdown_ = true; + { + std::lock_guard lock(mutex_); + shutdown_ = true; + } condition_lock_.notify_all(); - for (int i = 0; i < threads_.size(); i++) { - if (threads_[i].joinable()) { - threads_[i].join(); + for (auto iter = threads_.begin(); iter != threads_.end(); ++iter) { + if (iter->second.joinable()) { + iter->second.join(); } } LOG_SEGCORE_INFO_ << "Finish shutting down " << name_; } + +void +ThreadPool::FinishThreads() { + while (!need_finish_threads_.empty()) { + std::thread::id id; + auto dequeue = need_finish_threads_.dequeue(id); + if (dequeue) { + auto iter = threads_.find(id); + assert(iter != threads_.end()); + if (iter->second.joinable()) { + iter->second.join(); + } + threads_.erase(iter); + } + } +} + +void +ThreadPool::Worker() { + std::function func; + bool dequeue; + while (!shutdown_) { + std::unique_lock lock(mutex_); + idle_threads_size_++; + auto is_timeout = !condition_lock_.wait_for( + lock, std::chrono::seconds(WAIT_SECONDS), [this]() { + return shutdown_ || !work_queue_.empty(); + }); + idle_threads_size_--; + if (work_queue_.empty()) { + // Dynamic reduce thread number + if (shutdown_) { + current_threads_size_--; + return; + } + if (is_timeout) { + FinishThreads(); + if (current_threads_size_ > min_threads_size_) { + need_finish_threads_.enqueue(std::this_thread::get_id()); + current_threads_size_--; + return; + } + continue; + } + } + dequeue = work_queue_.dequeue(func); + lock.unlock(); + if (dequeue) { + func(); + } + } +} }; // namespace milvus diff --git a/internal/core/src/storage/ThreadPool.h b/internal/core/src/storage/ThreadPool.h index 0736e1a23f..dc62c2d0ca 100644 --- a/internal/core/src/storage/ThreadPool.h +++ b/internal/core/src/storage/ThreadPool.h @@ -36,10 +36,13 @@ class ThreadPool { explicit ThreadPool(const int thread_core_coefficient, const std::string& name) : shutdown_(false), name_(name) { - auto thread_num = CPU_NUM * thread_core_coefficient; - threads_ = std::vector(thread_num); + idle_threads_size_ = 0; + current_threads_size_ = 0; + min_threads_size_ = CPU_NUM; + max_threads_size_ = CPU_NUM * thread_core_coefficient; LOG_SEGCORE_INFO_ << "Init thread pool:" << name_ - << " with worker num:" << thread_num; + << " with min worker num:" << min_threads_size_ + << " and max worker num:" << max_threads_size_; Init(); } @@ -60,6 +63,12 @@ class ThreadPool { void ShutDown(); + size_t + GetThreadNum() { + std::lock_guard lock(mutex_); + return current_threads_size_; + } + template auto // Submit(F&& f, Args&&... args) -> std::future; @@ -73,15 +82,37 @@ class ThreadPool { work_queue_.enqueue(wrap_func); - condition_lock_.notify_one(); + std::lock_guard lock(mutex_); + + if (idle_threads_size_ > 0) { + condition_lock_.notify_one(); + } else if (current_threads_size_ < max_threads_size_) { + // Dynamic increase thread number + std::thread t(&ThreadPool::Worker, this); + assert(threads_.find(t.get_id()) == threads_.end()); + threads_[t.get_id()] = std::move(t); + current_threads_size_++; + } return task_ptr->get_future(); } + void + Worker(); + + void + FinishThreads(); + public: + int min_threads_size_; + int idle_threads_size_; + int current_threads_size_; + int max_threads_size_; bool shutdown_; + static constexpr size_t WAIT_SECONDS = 2; SafeQueue> work_queue_; - std::vector threads_; + std::unordered_map threads_; + SafeQueue need_finish_threads_; std::mutex mutex_; std::condition_variable condition_lock_; std::string name_; diff --git a/internal/core/unittest/test_disk_file_manager_test.cpp b/internal/core/unittest/test_disk_file_manager_test.cpp index dc6b970e88..146a34f775 100644 --- a/internal/core/unittest/test_disk_file_manager_test.cpp +++ b/internal/core/unittest/test_disk_file_manager_test.cpp @@ -106,8 +106,44 @@ test_worker(string s) { return 1; } +int +compute(int a) { + return a + 10; +} + +TEST_F(DiskAnnFileManagerTest, TestThreadPoolBase) { + auto thread_pool = std::make_shared(10, "test1"); + std::cout << "current thread num" << thread_pool->GetThreadNum() + << std::endl; + auto thread_num_1 = thread_pool->GetThreadNum(); + EXPECT_GT(thread_num_1, 0); + + auto fut = thread_pool->Submit(compute, 10); + auto res = fut.get(); + EXPECT_EQ(res, 20); + + std::vector> futs; + for (int i = 0; i < 10; ++i) { + futs.push_back(thread_pool->Submit(compute, i)); + } + std::cout << "current thread num" << thread_pool->GetThreadNum() + << std::endl; + auto thread_num_2 = thread_pool->GetThreadNum(); + EXPECT_GT(thread_num_2, thread_num_1); + + for (int i = 0; i < 10; ++i) { + std::cout << futs[i].get() << std::endl; + } + + sleep(5); + std::cout << "current thread num" << thread_pool->GetThreadNum() + << std::endl; + auto thread_num_3 = thread_pool->GetThreadNum(); + EXPECT_LT(thread_num_3, thread_num_2); +} + TEST_F(DiskAnnFileManagerTest, TestThreadPool) { - auto thread_pool = new milvus::ThreadPool(50, "test"); + auto thread_pool = std::make_shared(50, "test"); std::vector> futures; auto start = chrono::system_clock::now(); for (int i = 0; i < 100; i++) { @@ -121,6 +157,7 @@ TEST_F(DiskAnnFileManagerTest, TestThreadPool) { auto duration = chrono::duration_cast(end - start); auto second = double(duration.count()) * chrono::microseconds::period::num / chrono::microseconds::period::den; + std::cout << "cost time:" << second << std::endl; EXPECT_LT(second, 4 * 100); } @@ -134,7 +171,7 @@ test_exception(string s) { TEST_F(DiskAnnFileManagerTest, TestThreadPoolException) { try { - auto thread_pool = new milvus::ThreadPool(50, "test"); + auto thread_pool = std::make_shared(50, "test"); std::vector> futures; for (int i = 0; i < 100; i++) { futures.push_back(thread_pool->Submit(