mirror of
https://gitee.com/fasiondog/hikyuu.git
synced 2024-12-02 20:08:26 +08:00
MQThreadPool
This commit is contained in:
parent
a52404516f
commit
b46961382f
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* MQThreadPool.h
|
||||
* StealMQThreadPool.h
|
||||
*
|
||||
* Copyright (c) 2019 hikyuu.org
|
||||
*
|
||||
@ -8,51 +8,49 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#ifndef HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H
|
||||
#define HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H
|
||||
#ifndef HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H
|
||||
#define HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H
|
||||
|
||||
//#include <fmt/format.h>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <vector>
|
||||
#include "FuncWrapper.h"
|
||||
#include "ThreadSafeQueue.h"
|
||||
#include "WorkStealQueue.h"
|
||||
|
||||
namespace hku {
|
||||
|
||||
/**
|
||||
* @brief 分布式偷取式线程池(无中央队列)
|
||||
* @note 主要用于存在递归情况,任务又创建任务加入线程池的情况,否则建议使用普通的线程池
|
||||
* @brief 普通多任务队列线程池,任务之间彼此独立不能互相等待
|
||||
* @note 任务运行之间如存在先后顺序,请使用 StealThreadPool。
|
||||
* @details
|
||||
* @ingroup ThreadPool
|
||||
* @ingroup MQThreadPool
|
||||
*/
|
||||
class MQThreadPool {
|
||||
public:
|
||||
/**
|
||||
* 默认构造函数,创建和当前系统CPU数一致的线程数
|
||||
*/
|
||||
MQThreadPool() : MQThreadPool(std::thread::hardware_concurrency()) {}
|
||||
MQThreadPool() : MQThreadPool(std::thread::hardware_concurrency(), true) {}
|
||||
|
||||
/**
|
||||
* 构造函数,创建指定数量的线程
|
||||
* @param n 指定的线程数
|
||||
*/
|
||||
explicit MQThreadPool(size_t n)
|
||||
: m_done(false), m_init_finished(false), m_worker_num(n), m_current_alloc_index(0) {
|
||||
explicit MQThreadPool(size_t n, bool util_empty = true)
|
||||
: m_done(false), m_worker_num(n), m_runnging_util_empty(util_empty) {
|
||||
try {
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
// 创建工作线程及其任务队列
|
||||
m_queues.push_back(std::unique_ptr<WorkStealQueue>(new WorkStealQueue));
|
||||
m_queues.push_back(
|
||||
std::unique_ptr<ThreadSafeQueue<task_type>>(new ThreadSafeQueue<task_type>));
|
||||
m_threads.push_back(std::thread(&MQThreadPool::worker_thread, this, i));
|
||||
m_cvs.push_back(std::make_unique<std::condition_variable>());
|
||||
m_cv_mutexs.push_back(std::make_unique<std::mutex>());
|
||||
}
|
||||
} catch (...) {
|
||||
m_done = true;
|
||||
throw;
|
||||
}
|
||||
m_init_finished = true;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -84,15 +82,22 @@ public:
|
||||
typedef typename std::result_of<FunctionType()>::type result_type;
|
||||
std::packaged_task<result_type()> task(f);
|
||||
task_handle<result_type> res(task.get_future());
|
||||
if (m_local_work_queue) {
|
||||
// 本地线程任务从前部入队列(递归成栈)
|
||||
m_local_work_queue->push_front(std::move(task));
|
||||
} else {
|
||||
m_queues[m_current_alloc_index++].push(std::move(task));
|
||||
if (m_current_alloc_index >= m_worker_num) {
|
||||
m_current_alloc_index = 0;
|
||||
size_t min_count = std::numeric_limits<size_t>::max();
|
||||
int index = -1;
|
||||
for (int i = 0; i < m_worker_num; ++i) {
|
||||
size_t cur_count = m_queues[i]->size();
|
||||
if (cur_count == 0) {
|
||||
index = i;
|
||||
break;
|
||||
}
|
||||
|
||||
if (cur_count < min_count) {
|
||||
min_count = cur_count;
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
|
||||
m_queues[index]->push(std::move(task));
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -113,10 +118,9 @@ public:
|
||||
|
||||
// 同时加入结束任务指示,以便在dll退出时也能够终止
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
m_queues[i]->push_front(std::move(FuncWrapper()));
|
||||
m_queues[i]->push(std::move(FuncWrapper()));
|
||||
}
|
||||
|
||||
m_cv.notify_all(); // 唤醒所有工作线程
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
if (m_threads[i].joinable()) {
|
||||
m_threads[i].join();
|
||||
@ -130,13 +134,12 @@ public:
|
||||
*/
|
||||
void join() {
|
||||
// 指示各工作线程在未获取到工作任务时,停止运行
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
m_master_work_queue.push(std::move(FuncWrapper()));
|
||||
if (m_runnging_util_empty) {
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
m_queues[i]->push(std::move(FuncWrapper()));
|
||||
}
|
||||
}
|
||||
|
||||
// 唤醒所有工作线程
|
||||
m_cv.notify_all();
|
||||
|
||||
// 等待线程结束
|
||||
for (size_t i = 0; i < m_worker_num; i++) {
|
||||
if (m_threads[i].joinable()) {
|
||||
@ -149,18 +152,16 @@ public:
|
||||
|
||||
private:
|
||||
typedef FuncWrapper task_type;
|
||||
std::atomic_bool m_done; // 线程池全局需终止指示
|
||||
bool m_init_finished; // 线程池是否初始化完毕
|
||||
size_t m_worker_num; // 工作线程数量
|
||||
std::atomic_bool m_done; // 线程池全局需终止指示
|
||||
size_t m_worker_num; // 工作线程数量
|
||||
bool m_runnging_util_empty; // 运行直到队列空时停止
|
||||
|
||||
std::vector<std::unique_ptr<WorkStealQueue>> m_queues; // 任务队列(每个工作线程一个)
|
||||
std::vector<std::thread> m_threads; // 工作线程
|
||||
std::vector<std::unique_ptr<std::condition_variable>> m_cvs; // 信号量,无任务时阻塞线程并等待
|
||||
std::vector<std::unique_ptr<std::mutex>> m_cv_mutexs; // 配合信号量的互斥量
|
||||
size_t m_current_alloc_index; // 当前分配的线程号
|
||||
std::vector<std::unique_ptr<ThreadSafeQueue<task_type>>> m_queues; // 线程任务队列
|
||||
std::vector<std::thread> m_threads; // 工作线程
|
||||
|
||||
// 线程本地变量
|
||||
inline static thread_local WorkStealQueue* m_local_work_queue = nullptr; // 本地任务队列
|
||||
inline static thread_local ThreadSafeQueue<task_type>* m_local_work_queue =
|
||||
nullptr; // 本地任务队列
|
||||
inline static thread_local size_t m_index = 0; //在线程池中的序号
|
||||
inline static thread_local bool m_thread_need_stop = false; // 线程停止运行指示
|
||||
|
||||
@ -168,46 +169,25 @@ private:
|
||||
m_thread_need_stop = false;
|
||||
m_index = index;
|
||||
m_local_work_queue = m_queues[m_index].get();
|
||||
while (!m_thread_need_stop && !m_done) {
|
||||
while (!m_done && !m_thread_need_stop) {
|
||||
run_pending_task();
|
||||
std::this_thread::yield();
|
||||
}
|
||||
// fmt::print("thread ({}) finished!\n", std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void run_pending_task() {
|
||||
// 从本地队列提前工作任务,如本地无任务则从主队列中提取任务
|
||||
// 如果主队列中提取的任务是空任务,则认为需结束本线程,否则从其他工作队列中偷取任务
|
||||
task_type task;
|
||||
if (pop_task_from_local_queue(task)) {
|
||||
task();
|
||||
std::this_thread::yield();
|
||||
} else if (pop_task_from_other_thread_queue(task)) {
|
||||
task();
|
||||
std::this_thread::yield();
|
||||
m_local_work_queue->wait_and_pop(task);
|
||||
if (task.isNullTask()) {
|
||||
m_thread_need_stop = true;
|
||||
} else {
|
||||
std::this_thread::yield();
|
||||
task();
|
||||
}
|
||||
}
|
||||
|
||||
bool pop_task_from_local_queue(task_type& task) {
|
||||
return m_local_work_queue && m_local_work_queue->try_pop(task);
|
||||
}
|
||||
|
||||
bool pop_task_from_other_thread_queue(task_type& task) {
|
||||
// 线程池尚未初始化化完成时,其他任务队列可能尚未创建
|
||||
// 此时不能从其他队列偷取任务
|
||||
if (!m_init_finished) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < m_worker_num; ++i) {
|
||||
size_t index = (m_index + i + 1) % m_worker_num;
|
||||
if (m_queues[index]->try_steal(task)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}; // namespace hku
|
||||
|
||||
} /* namespace hku */
|
||||
|
||||
#endif /* HIKYUU_UTILITIES_THREAD_STEALTHREADPOOL_H */
|
||||
#endif /* HIKYUU_UTILITIES_MQTHREAD_THREADPOOL_H */
|
@ -68,6 +68,11 @@ public:
|
||||
return m_queue.empty();
|
||||
}
|
||||
|
||||
// 队列大小,无锁
|
||||
size_t size() const {
|
||||
return m_queue.size();
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::mutex m_mutex;
|
||||
std::queue<T> m_queue;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "doctest/doctest.h"
|
||||
#include <hikyuu/utilities/thread/StealThreadPool.h>
|
||||
#include <hikyuu/utilities/thread/ThreadPool.h>
|
||||
#include <hikyuu/utilities/thread/MQThreadPool.h>
|
||||
#include <hikyuu/utilities/SpendTimer.h>
|
||||
#include <hikyuu/Log.h>
|
||||
|
||||
@ -36,6 +37,21 @@ TEST_CASE("test_ThreadPool") {
|
||||
}
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_MQThreadPool") {
|
||||
{
|
||||
SPEND_TIME(test_MQThreadPool);
|
||||
MQThreadPool tg(8);
|
||||
HKU_INFO("worker_num: {}", tg.worker_num());
|
||||
for (int i = 0; i < 10; i++) {
|
||||
tg.submit([=]() { // fmt::print("{}: ----------------------\n", i);
|
||||
HKU_INFO("{}: ------------------- [{}]", i, std::this_thread::get_id());
|
||||
});
|
||||
}
|
||||
tg.join();
|
||||
}
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_StealThreadPool") {
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user