StealTask (continue)

This commit is contained in:
fasiondog 2020-04-27 01:58:11 +08:00
parent 0c96fedc65
commit c4d4a799a4
10 changed files with 279 additions and 60 deletions

View File

@ -0,0 +1,58 @@
/*
* StealRunnerQueue.h
*
* Copyright (c) hikyuu.org
*
* Created on: 2020-4-27
* Author: fasiondog
*/
#include "StealRunnerQueue.h"
namespace hku {
/* 将数据插入队列头部 */
void StealRunnerQueue::push_front(const StealTaskPtr& task) {
std::lock_guard<std::mutex> lock(m_mutex);
m_queue.push_front(task);
}
/* 将数据插入队列尾部 */
void StealRunnerQueue::push_back(const StealTaskPtr& task) {
std::lock_guard<std::mutex> lock(m_mutex);
m_queue.push_back(task);
}
/* 队列是否为空 */
bool StealRunnerQueue::empty() const {
std::lock_guard<std::mutex> lock(m_mutex);
return m_queue.empty();
}
/* 尝试从队列头部弹出一条数数据, 如果失败返回空指针 */
StealTaskPtr StealRunnerQueue::try_pop() {
std::lock_guard<std::mutex> lock(m_mutex);
StealTaskPtr result;
if (m_queue.empty()) {
return result;
}
result = m_queue.front();
m_queue.pop_front();
return result;
}
/* 尝试从队列尾部偷取一条数据,失败返回空指针 */
StealTaskPtr StealRunnerQueue::try_steal() {
std::lock_guard<std::mutex> lock(m_mutex);
StealTaskPtr result;
if (m_queue.empty()) {
return result;
}
result = m_queue.back();
m_queue.pop_back();
return result;
}
} /* namespace hku */

View File

@ -0,0 +1,67 @@
/*
* StealRunnerQueue.h
*
* Copyright (c) hikyuu.org
*
* Created on: 2020-4-26
* Author: fasiondog
*/
#pragma once
#ifndef HIKYUU_UTILITIES_TASK_STEAL_RUNNER_QUEUE_H
#define HIKYUU_UTILITIES_TASK_STEAL_RUNNER_QUEUE_H
#include <deque>
#include <mutex>
#include "StealTaskBase.h"
namespace hku {
/**
*
*/
class StealRunnerQueue {
public:
/** 构造函数 */
StealRunnerQueue() = default;
~StealRunnerQueue() = default;
// 禁用赋值构造和赋值重载
StealRunnerQueue(const StealRunnerQueue& other) = delete;
StealRunnerQueue& operator=(const StealRunnerQueue& other) = delete;
/** 将数据插入队列头部 */
void push_front(const StealTaskPtr& task);
/** 将数据插入队列尾部 */
void push_back(const StealTaskPtr& task);
/** 队列是否为空 */
bool empty() const;
size_t size() const {
return m_queue.size();
}
/**
*
* @param res
* @return
*/
StealTaskPtr try_pop();
/**
*
* @param res
* @return
*/
StealTaskPtr try_steal();
private:
std::deque<StealTaskPtr> m_queue;
mutable std::mutex m_mutex;
};
} /* namespace hku */
#endif /* HIKYUU_UTILITIES_TASK_STEAL_RUNNER_QUEUE_H */

View File

@ -6,38 +6,34 @@
*/ */
#include <iostream> #include <iostream>
#include "../../Log.h"
#include "StealTaskBase.h" #include "StealTaskBase.h"
#include "StealTaskRunner.h" #include "StealTaskRunner.h"
namespace hku { namespace hku {
StealTaskBase::StealTaskBase() { StealTaskBase::StealTaskBase() {
_done = false; m_done = false;
_runner = NULL; m_runner = NULL;
} }
StealTaskBase::~StealTaskBase() {} StealTaskBase::~StealTaskBase() {}
void StealTaskBase::fork(StealTaskRunner* runner) {
setTaskRunner(runner);
if (runner) {
runner->putTask(shared_from_this());
} else {
std::cerr << "[TaskBase::fork] Invalid Runner!" << std::endl;
}
}
void StealTaskBase::join() { void StealTaskBase::join() {
if (_runner) { if (m_runner) {
_runner->taskJoin(shared_from_this()); m_runner->taskJoin(shared_from_this());
} else { } else {
std::cerr << "[TaskBase::join] Invalid Runner!" << std::endl; HKU_ERROR("Invalid runner!");
} }
} }
void StealTaskBase::invoke() { void StealTaskBase::invoke() {
run(); run();
_done = true; m_done = true;
}
void StealTaskBase::run() {
HKU_WARN("This is empty task!");
} }
} // namespace hku } // namespace hku

View File

@ -19,36 +19,49 @@
namespace hku { namespace hku {
class StealTaskRunner; class StealTaskRunner;
class StealTaskGroup;
/** /**
* *
* @ingroup TaskGroup * @ingroup TaskGroup
*/ */
class HKU_API StealTaskBase : public std::enable_shared_from_this<StealTaskBase> { class HKU_API StealTaskBase : public std::enable_shared_from_this<StealTaskBase> {
friend class StealTaskRunner;
friend class StealTaskGroup;
public: public:
StealTaskBase(); StealTaskBase();
virtual ~StealTaskBase(); virtual ~StealTaskBase();
virtual void run() = 0; /**
*
*/
virtual void run();
/**
*
*/
bool isDone() const { bool isDone() const {
return _done; return m_done;
} }
void fork(StealTaskRunner *); /**
*
*/
void join(); void join();
private:
// StealTaskRunner 实际执行任务
void invoke(); void invoke();
// StealTaskGroup 设置
void setTaskRunner(StealTaskRunner *runner) { void setTaskRunner(StealTaskRunner *runner) {
_runner = runner; m_runner = runner;
}
StealTaskRunner *getTaskRunner() {
return _runner;
} }
private: private:
mutable bool _done; mutable bool m_done; // 标记该任务是否已执行完毕
mutable StealTaskRunner *_runner; mutable StealTaskRunner *m_runner;
}; };
typedef std::shared_ptr<StealTaskBase> StealTaskPtr; typedef std::shared_ptr<StealTaskBase> StealTaskPtr;

View File

@ -6,6 +6,7 @@
*/ */
#include <iostream> #include <iostream>
#include "../../Log.h"
#include "StealTaskGroup.h" #include "StealTaskGroup.h"
namespace hku { namespace hku {
@ -47,15 +48,19 @@ private:
StealTaskGroup::StealTaskGroup(size_t taskCount, size_t groupSize) { StealTaskGroup::StealTaskGroup(size_t taskCount, size_t groupSize) {
_taskList.reserve(taskCount); _taskList.reserve(taskCount);
_groupSize = (groupSize != 0) ? groupSize : std::thread::hardware_concurrency(); m_runnerNum = (groupSize != 0) ? groupSize : std::thread::hardware_concurrency();
_runnerList.reserve(_groupSize); _runnerList.reserve(m_runnerNum);
_stopTask = StealTaskPtr(new StopTask()); _stopTask = StealTaskPtr(new StopTask());
_currentRunnerId = 0; _currentRunnerId = 0;
for (size_t i = 0; i < _groupSize; i++) { for (size_t i = 0; i < m_runnerNum; i++) {
StealTaskRunnerPtr runner(new StealTaskRunner(this, i, _stopTask)); StealTaskRunnerPtr runner(new StealTaskRunner(this, i, _stopTask));
_runnerList.push_back(runner); _runnerList.push_back(runner);
} }
for (auto i = 0; i < m_runnerNum; i++) {
m_runner_queues.push_back(std::make_shared<StealRunnerQueue>());
}
start(); start();
} }
@ -63,7 +68,7 @@ StealTaskGroup::~StealTaskGroup() {}
StealTaskRunnerPtr StealTaskGroup::getRunner(size_t id) { StealTaskRunnerPtr StealTaskGroup::getRunner(size_t id) {
StealTaskRunnerPtr result; StealTaskRunnerPtr result;
if (id >= _groupSize) { if (id >= m_runnerNum) {
std::cerr << "[StealTaskGroup::getRunner] Invalid id: " << id << std::endl; std::cerr << "[StealTaskGroup::getRunner] Invalid id: " << id << std::endl;
return result; return result;
} }
@ -74,17 +79,23 @@ StealTaskRunnerPtr StealTaskGroup::getRunner(size_t id) {
StealTaskRunnerPtr StealTaskGroup::getCurrentRunner() { StealTaskRunnerPtr StealTaskGroup::getCurrentRunner() {
StealTaskRunnerPtr result = _runnerList[_currentRunnerId]; StealTaskRunnerPtr result = _runnerList[_currentRunnerId];
_currentRunnerId++; _currentRunnerId++;
if (_currentRunnerId >= _groupSize) { if (_currentRunnerId >= m_runnerNum) {
_currentRunnerId = 0; _currentRunnerId = 0;
} }
return result; return result;
} }
void StealTaskGroup::addTask(const StealTaskPtr& task) { StealTaskPtr StealTaskGroup::addTask(const StealTaskPtr& task) {
if (StealTaskRunner::m_local_queue) {
HKU_INFO("add task to local queue!");
StealTaskRunner::m_local_queue->push_front(task);
}
_taskList.push_back(task); _taskList.push_back(task);
StealTaskRunnerPtr runner = getCurrentRunner(); StealTaskRunnerPtr runner = getCurrentRunner();
task->setTaskRunner(runner.get()); task->setTaskRunner(runner.get());
runner->putTask(task); runner->putTask(task);
return task;
} }
void StealTaskGroup::start() { void StealTaskGroup::start() {

View File

@ -10,6 +10,7 @@
#define STEALTASKGROUP_H_ #define STEALTASKGROUP_H_
#include "StealTaskRunner.h" #include "StealTaskRunner.h"
#include "StealRunnerQueue.h"
namespace hku { namespace hku {
@ -18,6 +19,8 @@ namespace hku {
* @ingroup TaskGroup * @ingroup TaskGroup
*/ */
class HKU_API StealTaskGroup { class HKU_API StealTaskGroup {
friend class StealTaskRunner;
public: public:
/** /**
* *
@ -26,16 +29,24 @@ public:
* @return * @return
*/ */
StealTaskGroup(size_t taskCount = 3072, size_t groupSize = 0); StealTaskGroup(size_t taskCount = 3072, size_t groupSize = 0);
/**
*
*/
virtual ~StealTaskGroup(); virtual ~StealTaskGroup();
/**
* 线
*/
size_t size() const { size_t size() const {
return _groupSize; return m_runnerNum;
} }
StealTaskRunnerPtr getRunner(size_t id); StealTaskRunnerPtr getRunner(size_t id);
StealTaskRunnerPtr getCurrentRunner(); StealTaskRunnerPtr getCurrentRunner();
//增加一个任务 //增加一个任务
void addTask(const StealTaskPtr& task); StealTaskPtr addTask(const StealTaskPtr& task);
void start(); void start();
@ -58,9 +69,11 @@ private:
typedef std::vector<StealTaskRunnerPtr> RunnerList; typedef std::vector<StealTaskRunnerPtr> RunnerList;
RunnerList _runnerList; RunnerList _runnerList;
StealTaskList _taskList; StealTaskList _taskList;
size_t _groupSize; size_t m_runnerNum;
StealTaskPtr _stopTask; StealTaskPtr _stopTask;
size_t _currentRunnerId; //记录当前执行addTask任务时需放入的TaskRunnerid用于均衡任务分配 size_t _currentRunnerId; //记录当前执行addTask任务时需放入的TaskRunnerid用于均衡任务分配
std::vector<std::shared_ptr<StealRunnerQueue>> m_runner_queues; // 任务队列(每个工作线程一个)
}; };
typedef std::shared_ptr<StealTaskGroup> StealTaskGroupPtr; typedef std::shared_ptr<StealTaskGroup> StealTaskGroupPtr;

View File

@ -11,28 +11,26 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include "../../Log.h"
#include "StealTaskRunner.h" #include "StealTaskRunner.h"
#include "StealTaskGroup.h" #include "StealTaskGroup.h"
#define QUEUE_LOCK std::lock_guard<std::mutex> lock(_mutex); #define QUEUE_LOCK std::lock_guard<std::mutex> lock(m_queue_mutex);
namespace hku { namespace hku {
StealTaskRunner::StealTaskRunner(StealTaskGroup* group, size_t id, StealTaskPtr stopTask) { StealTaskRunner::StealTaskRunner(StealTaskGroup* group, size_t id, StealTaskPtr stopTask) {
_id = id; m_index = id;
_group = group; m_group = group;
_stopTask = stopTask; _stopTask = stopTask;
} }
StealTaskRunner::~StealTaskRunner() {} StealTaskRunner::~StealTaskRunner() {}
/** // 加入一个普通任务,将其放入私有队列的后端
*
* @param task
*/
void StealTaskRunner::putTask(const StealTaskPtr& task) { void StealTaskRunner::putTask(const StealTaskPtr& task) {
QUEUE_LOCK; QUEUE_LOCK;
_queue.push_back(task); m_queue.push_back(task);
} }
/** /**
@ -41,7 +39,7 @@ void StealTaskRunner::putTask(const StealTaskPtr& task) {
*/ */
void StealTaskRunner::putWatchTask(const StealTaskPtr& task) { void StealTaskRunner::putWatchTask(const StealTaskPtr& task) {
QUEUE_LOCK; QUEUE_LOCK;
_queue.push_front(task); m_queue.push_front(task);
} }
/** /**
@ -51,9 +49,9 @@ void StealTaskRunner::putWatchTask(const StealTaskPtr& task) {
StealTaskPtr StealTaskRunner::takeTaskBySelf() { StealTaskPtr StealTaskRunner::takeTaskBySelf() {
QUEUE_LOCK; QUEUE_LOCK;
StealTaskPtr result; StealTaskPtr result;
if (!_queue.empty()) { if (!m_queue.empty()) {
result = _queue.back(); result = m_queue.back();
_queue.pop_back(); m_queue.pop_back();
} }
return result; return result;
@ -66,12 +64,12 @@ StealTaskPtr StealTaskRunner::takeTaskBySelf() {
StealTaskPtr StealTaskRunner::takeTaskByOther() { StealTaskPtr StealTaskRunner::takeTaskByOther() {
QUEUE_LOCK; QUEUE_LOCK;
StealTaskPtr result; StealTaskPtr result;
if (!_queue.empty()) { if (!m_queue.empty()) {
StealTaskPtr front = _queue.front(); StealTaskPtr front = m_queue.front();
//如果提取的任务是停止任务,则放弃并返回空 //如果提取的任务是停止任务,则放弃并返回空
if (front != _stopTask) { if (front != _stopTask) {
result = front; result = front;
_queue.pop_front(); m_queue.pop_front();
} }
} }
@ -96,6 +94,9 @@ void StealTaskRunner::join() {
* *
*/ */
void StealTaskRunner::run() { void StealTaskRunner::run() {
m_local_queue = m_group->m_runner_queues[m_index].get();
m_local_index = m_index;
m_locla_need_stop = false;
StealTaskPtr task; StealTaskPtr task;
try { try {
while (task != _stopTask) { while (task != _stopTask) {
@ -109,6 +110,7 @@ void StealTaskRunner::run() {
steal(StealTaskPtr()); steal(StealTaskPtr());
} }
} }
HKU_INFO("{} local size: {}", std::this_thread::get_id(), m_local_queue->size());
} catch (...) { } catch (...) {
std::cerr << "[TaskRunner::run] Some error!" << std::endl; std::cerr << "[TaskRunner::run] Some error!" << std::endl;
@ -116,7 +118,7 @@ void StealTaskRunner::run() {
} }
/** /**
* *
* @param waitingFor - * @param waitingFor -
*/ */
void StealTaskRunner::taskJoin(const StealTaskPtr& waitingFor) { void StealTaskRunner::taskJoin(const StealTaskPtr& waitingFor) {
@ -145,10 +147,10 @@ void StealTaskRunner::steal(const StealTaskPtr& waitingFor) {
std::srand(temp); std::srand(temp);
#endif #endif
size_t total = _group->size(); size_t total = m_group->size();
size_t ran_num = std::rand() % total; size_t ran_num = std::rand() % total;
for (size_t i = 0; i < total; i++) { for (size_t i = 0; i < total; i++) {
StealTaskRunnerPtr tr = _group->getRunner(ran_num); StealTaskRunnerPtr tr = m_group->getRunner(ran_num);
if (waitingFor && waitingFor->isDone()) { if (waitingFor && waitingFor->isDone()) {
break; break;
} }

View File

@ -13,6 +13,7 @@
#include <deque> #include <deque>
#include <list> #include <list>
#include "StealTaskBase.h" #include "StealTaskBase.h"
#include "StealRunnerQueue.h"
namespace hku { namespace hku {
@ -22,12 +23,21 @@ class StealTaskGroup;
* *
* @ingroup TaskGroup * @ingroup TaskGroup
*/ */
class HKU_API StealTaskRunner { class StealTaskRunner {
friend class StealTaskGroup;
friend class StealTaskBase;
public: public:
StealTaskRunner(StealTaskGroup* group, size_t id, StealTaskPtr stopTask); StealTaskRunner(StealTaskGroup* group, size_t id, StealTaskPtr stopTask);
virtual ~StealTaskRunner(); virtual ~StealTaskRunner();
private:
/**
*
* @param task
*/
void putTask(const StealTaskPtr&); void putTask(const StealTaskPtr&);
void putWatchTask(const StealTaskPtr&); void putWatchTask(const StealTaskPtr&);
StealTaskPtr takeTaskBySelf(); StealTaskPtr takeTaskBySelf();
@ -40,19 +50,24 @@ public:
void taskJoin(const StealTaskPtr& waitingFor); void taskJoin(const StealTaskPtr& waitingFor);
StealTaskGroup* getTaskRunnerGroup() { StealTaskGroup* getTaskRunnerGroup() {
return _group; return m_group;
} }
private: private:
size_t _id; size_t m_index; // 表示在任务组中的第几个线程
StealTaskGroup* _group; StealTaskGroup* m_group; // 所属任务组的指针
StealTaskPtr _stopTask; StealTaskPtr _stopTask;
std::thread m_thread; inline static thread_local StealRunnerQueue* m_local_queue = nullptr; //本地任务队列
std::mutex _mutex; inline static thread_local size_t m_local_index = 0; // 在任务组中的序号(m_index)
inline static thread_local bool m_locla_need_stop = false; // 线程停止运行指示
typedef std::list<StealTaskPtr> Queue; std::thread m_thread; // 本地工作线程
Queue _queue;
// 线程内工作任务队列
std::mutex m_queue_mutex;
typedef std::deque<StealTaskPtr> Queue;
Queue m_queue;
}; };
typedef std::shared_ptr<StealTaskRunner> StealTaskRunnerPtr; typedef std::shared_ptr<StealTaskRunner> StealTaskRunnerPtr;

View File

@ -0,0 +1,43 @@
/*
* test_Parameter.cpp
*
* Created on: 2020-4-26
* Author: fasiondog
*/
#include "doctest/doctest.h"
#include <hikyuu/Log.h>
#include <hikyuu/utilities/task/StealTaskBase.h>
#include <hikyuu/utilities/task/StealTaskGroup.h>
using namespace hku;
/**
* @defgroup test_hikyuu_TaskGroup test_hikyuu_TaskGroup
* @ingroup test_hikyuu_utilities
* @{
*/
class TestTask : public TaskBase {
public:
TestTask(int i) : m_i(i) {}
virtual ~TestTask() = default;
virtual void run() {
HKU_INFO("{}: *****************", m_i);
}
private:
int m_i;
};
/** @par 检测点 */
TEST_CASE("test_TaskGroup") {
TaskGroup tg;
for (int i = 0; i < 30; i++) {
tg.addTask(std::make_shared<TestTask>(i));
}
tg.run();
}
/** @} */

View File

@ -85,7 +85,8 @@ target("small-test")
end end
-- add files -- add files
add_files("./hikyuu/hikyuu/**.cpp"); add_files("./hikyuu/utilities/test_TaskGroup.cpp");
--add_files("./hikyuu/hikyuu/**.cpp");
--add_files("./hikyuu/hikyuu/test_StockManager.cpp"); --add_files("./hikyuu/hikyuu/test_StockManager.cpp");
add_files("./hikyuu/test_main.cpp") add_files("./hikyuu/test_main.cpp")