MQThreadPool

This commit is contained in:
fasiondog 2021-03-04 23:23:23 +08:00
parent a52404516f
commit b46961382f
3 changed files with 67 additions and 66 deletions

View File

@ -1,5 +1,5 @@
/*
* MQThreadPool.h
* StealMQThreadPool.h
*
* Copyright (c) 2019 hikyuu.org
*
@ -8,51 +8,49 @@
*/
#pragma once
#ifndef HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H
#define HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H
#ifndef HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H
#define HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H
//#include <fmt/format.h>
#include <future>
#include <thread>
#include <chrono>
#include <vector>
#include "FuncWrapper.h"
#include "ThreadSafeQueue.h"
#include "WorkStealQueue.h"
namespace hku {
/**
* @brief 线()
* @note 线使线
* @brief 线
* @note 使 StealThreadPool
* @details
* @ingroup ThreadPool
* @ingroup MQThreadPool
*/
class MQThreadPool {
public:
/**
* CPU数一致的线程数
*/
MQThreadPool() : MQThreadPool(std::thread::hardware_concurrency()) {}
MQThreadPool() : MQThreadPool(std::thread::hardware_concurrency(), true) {}
/**
* 线
* @param n 线
*/
explicit MQThreadPool(size_t n)
: m_done(false), m_init_finished(false), m_worker_num(n), m_current_alloc_index(0) {
explicit MQThreadPool(size_t n, bool util_empty = true)
: m_done(false), m_worker_num(n), m_runnging_util_empty(util_empty) {
try {
for (size_t i = 0; i < m_worker_num; i++) {
// 创建工作线程及其任务队列
m_queues.push_back(std::unique_ptr<WorkStealQueue>(new WorkStealQueue));
m_queues.push_back(
std::unique_ptr<ThreadSafeQueue<task_type>>(new ThreadSafeQueue<task_type>));
m_threads.push_back(std::thread(&MQThreadPool::worker_thread, this, i));
m_cvs.push_back(std::make_unique<std::condition_variable>());
m_cv_mutexs.push_back(std::make_unique<std::mutex>());
}
} catch (...) {
m_done = true;
throw;
}
m_init_finished = true;
}
/**
@ -84,15 +82,22 @@ public:
typedef typename std::result_of<FunctionType()>::type result_type;
std::packaged_task<result_type()> task(f);
task_handle<result_type> res(task.get_future());
if (m_local_work_queue) {
// 本地线程任务从前部入队列(递归成栈)
m_local_work_queue->push_front(std::move(task));
} else {
m_queues[m_current_alloc_index++].push(std::move(task));
if (m_current_alloc_index >= m_worker_num) {
m_current_alloc_index = 0;
size_t min_count = std::numeric_limits<size_t>::max();
int index = -1;
for (int i = 0; i < m_worker_num; ++i) {
size_t cur_count = m_queues[i]->size();
if (cur_count == 0) {
index = i;
break;
}
if (cur_count < min_count) {
min_count = cur_count;
index = i;
}
}
m_queues[index]->push(std::move(task));
return res;
}
@ -113,10 +118,9 @@ public:
// 同时加入结束任务指示以便在dll退出时也能够终止
for (size_t i = 0; i < m_worker_num; i++) {
m_queues[i]->push_front(std::move(FuncWrapper()));
m_queues[i]->push(std::move(FuncWrapper()));
}
m_cv.notify_all(); // 唤醒所有工作线程
for (size_t i = 0; i < m_worker_num; i++) {
if (m_threads[i].joinable()) {
m_threads[i].join();
@ -130,13 +134,12 @@ public:
*/
void join() {
// 指示各工作线程在未获取到工作任务时,停止运行
for (size_t i = 0; i < m_worker_num; i++) {
m_master_work_queue.push(std::move(FuncWrapper()));
if (m_runnging_util_empty) {
for (size_t i = 0; i < m_worker_num; i++) {
m_queues[i]->push(std::move(FuncWrapper()));
}
}
// 唤醒所有工作线程
m_cv.notify_all();
// 等待线程结束
for (size_t i = 0; i < m_worker_num; i++) {
if (m_threads[i].joinable()) {
@ -149,18 +152,16 @@ public:
private:
typedef FuncWrapper task_type;
std::atomic_bool m_done; // 线程池全局需终止指示
bool m_init_finished; // 线程池是否初始化完毕
size_t m_worker_num; // 工作线程数量
std::atomic_bool m_done; // 线程池全局需终止指示
size_t m_worker_num; // 工作线程数量
bool m_runnging_util_empty; // 运行直到队列空时停止
std::vector<std::unique_ptr<WorkStealQueue>> m_queues; // 任务队列(每个工作线程一个)
std::vector<std::thread> m_threads; // 工作线程
std::vector<std::unique_ptr<std::condition_variable>> m_cvs; // 信号量,无任务时阻塞线程并等待
std::vector<std::unique_ptr<std::mutex>> m_cv_mutexs; // 配合信号量的互斥量
size_t m_current_alloc_index; // 当前分配的线程号
std::vector<std::unique_ptr<ThreadSafeQueue<task_type>>> m_queues; // 线程任务队列
std::vector<std::thread> m_threads; // 工作线程
// 线程本地变量
inline static thread_local WorkStealQueue* m_local_work_queue = nullptr; // 本地任务队列
inline static thread_local ThreadSafeQueue<task_type>* m_local_work_queue =
nullptr; // 本地任务队列
inline static thread_local size_t m_index = 0; //在线程池中的序号
inline static thread_local bool m_thread_need_stop = false; // 线程停止运行指示
@ -168,46 +169,25 @@ private:
m_thread_need_stop = false;
m_index = index;
m_local_work_queue = m_queues[m_index].get();
while (!m_thread_need_stop && !m_done) {
while (!m_done && !m_thread_need_stop) {
run_pending_task();
std::this_thread::yield();
}
// fmt::print("thread ({}) finished!\n", std::this_thread::get_id());
}
void run_pending_task() {
// 从本地队列提前工作任务,如本地无任务则从主队列中提取任务
// 如果主队列中提取的任务是空任务,则认为需结束本线程,否则从其他工作队列中偷取任务
task_type task;
if (pop_task_from_local_queue(task)) {
task();
std::this_thread::yield();
} else if (pop_task_from_other_thread_queue(task)) {
task();
std::this_thread::yield();
m_local_work_queue->wait_and_pop(task);
if (task.isNullTask()) {
m_thread_need_stop = true;
} else {
std::this_thread::yield();
task();
}
}
bool pop_task_from_local_queue(task_type& task) {
return m_local_work_queue && m_local_work_queue->try_pop(task);
}
bool pop_task_from_other_thread_queue(task_type& task) {
// 线程池尚未初始化化完成时,其他任务队列可能尚未创建
// 此时不能从其他队列偷取任务
if (!m_init_finished) {
return false;
}
for (size_t i = 0; i < m_worker_num; ++i) {
size_t index = (m_index + i + 1) % m_worker_num;
if (m_queues[index]->try_steal(task)) {
return true;
}
}
return false;
}
}; // namespace hku
} /* namespace hku */
#endif /* HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H */
#endif /* HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H */

View File

@ -68,6 +68,11 @@ public:
return m_queue.empty();
}
// 队列大小,无锁
size_t size() const {
return m_queue.size();
}
private:
mutable std::mutex m_mutex;
std::queue<T> m_queue;

View File

@ -10,6 +10,7 @@
#include "doctest/doctest.h"
#include <hikyuu/utilities/thread/StealThreadPool.h>
#include <hikyuu/utilities/thread/ThreadPool.h>
#include <hikyuu/utilities/thread/MQThreadPool.h>
#include <hikyuu/utilities/SpendTimer.h>
#include <hikyuu/Log.h>
@ -36,6 +37,21 @@ TEST_CASE("test_ThreadPool") {
}
}
/** @par 检测点 */
TEST_CASE("test_MQThreadPool") {
{
SPEND_TIME(test_MQThreadPool);
MQThreadPool tg(8);
HKU_INFO("worker_num: {}", tg.worker_num());
for (int i = 0; i < 10; i++) {
tg.submit([=]() { // fmt::print("{}: ----------------------\n", i);
HKU_INFO("{}: ------------------- [{}]", i, std::this_thread::get_id());
});
}
tg.join();
}
}
/** @par 检测点 */
TEST_CASE("test_StealThreadPool") {
{