优化StrategyBase,以便支持strategy级别回测

This commit is contained in:
fasiondog 2024-05-07 17:48:46 +08:00
parent a766b14192
commit c63f9d6360
4 changed files with 53 additions and 28 deletions

View File

@ -31,6 +31,7 @@ public:
}
void startDatetime(const Datetime& d) {
HKU_CHECK(!d.isNull(), "Don't use null datetime!");
m_startDatetime = d;
}

View File

@ -58,7 +58,7 @@ void StrategyBase::_initDefaultParam() {
setParam<bool>("enable_2hour_clock", false);
}
void StrategyBase::run() {
void StrategyBase::_run(bool forTest) {
// 调用 strategy 自身的初始化方法
init();
@ -127,15 +127,10 @@ void StrategyBase::run() {
ktype_list.push_back(KQuery::DAY);
}
// 不使用默认的预加载模式
for (auto ktype : ktype_list) {
to_lower(ktype);
preloadParam.set<bool>(ktype, true);
string key(format("{}_max", ktype));
try {
preloadParam.set<int>(key, config.getInt("preload", key));
} catch (...) {
preloadParam.set<int>(key, 4096);
}
preloadParam.set<bool>(ktype, false);
}
sm.init(baseParam, blockParam, kdataParam, preloadParam, hkuParam, m_context);
@ -152,29 +147,54 @@ void StrategyBase::run() {
}
HKU_WARN_IF(m_stock_list.empty(), "[Strategy {}] stock list is empty!", m_name);
if (m_stock_list.size() > 0) {
const Stock& ref_stk = m_stock_list[0];
for (const auto& ktype : ktype_list) {
// 由于异步初始化此处不用通过先getCount再getKRecord的方式获取最后的KRecord
KRecordList klist = ref_stk.getKRecordList(KQueryByIndex(0, Null<int64_t>(), ktype));
size_t count = klist.size();
if (count > 0) {
m_ref_last_time[ktype] = klist[count - 1].datetime;
} else {
m_ref_last_time[ktype] = Null<Datetime>();
}
// 借助 Stock.setKRecordList 方法进行预加载(同步方式,不需要异步加载)
// 只从 context 指定起始日期开始加载
size_t ktype_count = ktype_list.size();
vector<KRecordList> k_buffer(ktype_count);
for (auto& stk : m_stock_list) {
// 保留原始 KDataDriver因为使用 stock.setKRecordList 将会把 stock 的 KDataDriver 设置为
// DoNothing
auto old_driver = stk.getKDataDirver();
for (size_t i = 0; i < ktype_count; i++) {
k_buffer[i] = std::move(stk.getKRecordList(
KQueryByDate(m_context.startDatetime(), Null<Datetime>(), ktype_list[i])));
}
for (size_t i = 0; i < ktype_count; i++) {
stk.setKRecordList(std::move(k_buffer[i]), ktype_list[i]);
}
// 恢复 KDataDriver
stk.setKDataDriver(old_driver);
}
// 启动行情接收代理
auto& agent = *getGlobalSpotAgent();
agent.addProcess([this](const SpotRecord& spot) { this->receivedSpot(spot); });
agent.addPostProcess([this](Datetime revTime) { this->finishReceivedSpot(revTime); });
startSpotAgent(false);
// 计算每个类型当前最后的日期
for (const auto& ktype : ktype_list) {
Datetime last_date = Datetime::min();
for (auto& stk : m_stock_list) {
size_t count = stk.getCount(ktype);
if (count > 1) {
auto kr = stk.getKRecord(count - 1, ktype);
if (kr.datetime > last_date) {
last_date = kr.datetime;
}
}
}
m_ref_last_time[ktype] = last_date == Datetime::min() ? Null<Datetime>() : last_date;
}
_addTimer();
if (!forTest) {
// 启动行情接收代理
auto& agent = *getGlobalSpotAgent();
agent.addProcess([this](const SpotRecord& spot) { this->receivedSpot(spot); });
agent.addPostProcess([this](Datetime revTime) { this->finishReceivedSpot(revTime); });
startSpotAgent(false);
_startEventLoop();
_addTimer();
HKU_INFO("start even loop ...");
_startEventLoop();
}
}
void StrategyBase::receivedSpot(const SpotRecord& spot) {

View File

@ -84,7 +84,9 @@ public:
return m_context.getKTypeList();
}
void run();
void run() {
_run(false);
}
void receivedSpot(const SpotRecord& spot);
void finishReceivedSpot(Datetime revTime);
@ -114,6 +116,8 @@ private:
void _addClockEvent(const string& enable, TimeDelta delta, TimeDelta openTime,
TimeDelta closeTime);
void _run(bool forTest);
private:
static std::atomic_bool ms_keep_running;
static void sig_handler(int sig);

View File

@ -42,7 +42,7 @@ public:
};
void export_Strategy(py::module& m) {
py::class_<StrategyBase, PyStrategyBase>(m, "StrategyBase")
py::class_<StrategyBase, StrategyPtr, PyStrategyBase>(m, "StrategyBase")
.def(py::init<>())
.def_property("name", py::overload_cast<>(&StrategyBase::name, py::const_),
py::overload_cast<const string&>(&StrategyBase::name),