From 352badc28c964fe5cfdcd7d24a28602af4043805 Mon Sep 17 00:00:00 2001 From: fasiondog Date: Mon, 16 Sep 2024 21:36:20 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=20OptimalSelector=20?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../selector/imp/OptimalSelector.cpp | 120 ++++++++++-------- .../trade_sys/selector/imp/OptimalSelector.h | 7 +- .../trade_sys/selector/test_SE_Optimal.cpp | 2 +- 3 files changed, 71 insertions(+), 58 deletions(-) diff --git a/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.cpp b/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.cpp index f6e80028..5ca4e516 100644 --- a/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.cpp +++ b/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.cpp @@ -77,9 +77,6 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery& bool trace = getParam("trace"); CLS_INFO_IF(trace, "candidate sys list size: {}", m_pro_sys_list.size()); - // for (const auto& sys : m_pro_sys_list) { - // HKU_DEBUG_IF(trace, "[SE_Optimal] {}", sys->name()); - // } size_t train_len = static_cast(getParam("train_len")); size_t test_len = static_cast(getParam("test_len")); @@ -102,18 +99,27 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery& string key = getParam("key"); int mode = getParam("mode"); - CLS_INFO_IF(trace, "statistic key: {}, mode: {}", key, mode); + CLS_INFO_IF(trace, "statistic key: {}, mode: {}", getParam("key"), + getParam("mode")); if (getParam("parallel")) { - _calculate_parallel(train_ranges, dates); - return; + _calculate_parallel(train_ranges, dates, key, mode, test_len, trace); + } else { + _calculate_single(train_ranges, dates, key, mode, test_len, trace); } + m_calculated = true; +} + +void OptimalSelector::_calculate_single(const vector>& train_ranges, + const DatetimeList& dates, const string& key, int mode, + size_t test_len, bool trace) { + size_t dates_len = dates.size(); Performance per; for (size_t i = 0, total = train_ranges.size(); i < total; i++) { Datetime start_date = dates[train_ranges[i].first]; Datetime end_date = dates[train_ranges[i].second]; - KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType()); + KQuery q = KQueryByDate(start_date, end_date, m_query.kType(), m_query.recoverType()); CLS_INFO_IF(trace, "iteration: {}|{}, range: {}", i + 1, total, q); SYSPtr selected_sys; if (m_pro_sys_list.size() == 1) { @@ -144,37 +150,38 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery& } } - HKU_ASSERT(selected_sys); - size_t test_start = train_ranges[i].second; - size_t test_end = test_start + test_len; - if (test_end > dates_len) { - test_end = dates_len; - } - selected_sys->reset(); - selected_sys = selected_sys->clone(); - for (size_t pos = test_start; pos < test_end; pos++) { - m_sys_dict[dates[pos]] = selected_sys; - } - if (test_end < dates_len) { - m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end])); - } else { - m_run_ranges.emplace_back( - std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1))); - } - CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name()); - } + if (selected_sys) { + selected_sys->reset(); + selected_sys = selected_sys->clone(); - m_calculated = true; + size_t test_start = train_ranges[i].second; + size_t test_end = test_start + test_len; + if (test_end > dates_len) { + test_end = dates_len; + } + + for (size_t pos = test_start; pos < test_end; pos++) { + m_sys_dict[dates[pos]] = selected_sys; + } + + if (test_end < dates_len) { + m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end])); + } else { + m_run_ranges.emplace_back( + std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1))); + } + + CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name()); + } + } } void OptimalSelector::_calculate_parallel(const vector>& train_ranges, - const DatetimeList& dates) { + const DatetimeList& dates, const string& key, int mode, + size_t test_len, bool trace) { auto sys_list = parallel_for_index( - 0, train_ranges.size(), [this, &train_ranges, &dates, query = m_query](size_t i) { - bool trace = getParam("trace"); - string key = getParam("key"); - int mode = getParam("mode"); - + 0, train_ranges.size(), + [this, &train_ranges, &dates, query = m_query, trace, key, mode](size_t i) { Datetime start_date = dates[train_ranges[i].first]; Datetime end_date = dates[train_ranges[i].second]; KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType()); @@ -210,35 +217,36 @@ void OptimalSelector::_calculate_parallel(const vector } } - HKU_IF_RETURN(!selected_sys, selected_sys); - - selected_sys->reset(); - return selected_sys->clone(); + return selected_sys; }); - // size_t train_len = static_cast(getParam("train_len")); - bool trace = getParam("trace"); size_t dates_len = dates.size(); - size_t test_len = static_cast(getParam("test_len")); for (size_t i = 0, total = train_ranges.size(); i < total; i++) { - size_t test_start = train_ranges[i].second; - size_t test_end = test_start + test_len; - if (test_end > dates_len) { - test_end = dates_len; + auto& selected_sys = sys_list[i]; + if (selected_sys) { + selected_sys->reset(); + selected_sys = selected_sys->clone(); + + size_t test_start = train_ranges[i].second; + size_t test_end = test_start + test_len; + if (test_end > dates_len) { + test_end = dates_len; + } + + for (size_t pos = test_start; pos < test_end; pos++) { + m_sys_dict[dates[pos]] = selected_sys; + } + + if (test_end < dates_len) { + m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end])); + } else { + m_run_ranges.emplace_back( + std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1))); + } + + CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name()); } - const auto& selected_sys = sys_list[i]; - for (size_t pos = test_start; pos < test_end; pos++) { - m_sys_dict[dates[pos]] = selected_sys; - } - if (test_end < dates_len) { - m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end])); - } else { - m_run_ranges.emplace_back( - std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1))); - } - CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name()); } - m_calculated = true; } SEPtr HKU_API SE_Optimal() { diff --git a/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.h b/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.h index bd6e570e..f17093a0 100644 --- a/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.h +++ b/hikyuu_cpp/hikyuu/trade_sys/selector/imp/OptimalSelector.h @@ -30,8 +30,13 @@ public: } private: + void _calculate_single(const vector>& train_ranges, + const DatetimeList& dates, const string& key, int mode, size_t test_len, + bool trace); + void _calculate_parallel(const vector>& train_ranges, - const DatetimeList& dates); + const DatetimeList& dates, const string& key, int mode, + size_t test_len, bool trace); private: unordered_map m_sys_dict; diff --git a/hikyuu_cpp/unit_test/hikyuu/trade_sys/selector/test_SE_Optimal.cpp b/hikyuu_cpp/unit_test/hikyuu/trade_sys/selector/test_SE_Optimal.cpp index e13acbf5..7f73fdec 100644 --- a/hikyuu_cpp/unit_test/hikyuu/trade_sys/selector/test_SE_Optimal.cpp +++ b/hikyuu_cpp/unit_test/hikyuu/trade_sys/selector/test_SE_Optimal.cpp @@ -123,7 +123,7 @@ TEST_CASE("test_SE_Optimal") { query = KQueryByIndex(-125); // se->setParam("trace", true); - se->setParam("parallel", true); + // se->setParam("parallel", true); se->setParam("train_len", 30); se->setParam("test_len", 20); se->calculate(SystemList(), query);