补充 SYS_WalkForword 测试及相关修正

This commit is contained in:
KongDong 2024-11-03 01:31:23 +08:00
parent cafabae9ce
commit 556a206193
4 changed files with 198 additions and 15 deletions

View File

@ -139,10 +139,12 @@ void OptimalSelectorBase::_calculate_single(const vector<std::pair<size_t, size_
selected_sys_list = std::make_shared<SystemWeightList>();
for (const auto& sys : m_pro_sys_list) {
try {
sys->run(q, true);
double value = evaluate(sys, end_date);
auto nsys = sys->clone();
nsys->run(q, true);
double value = evaluate(nsys, end_date);
nsys->reset();
if (!std::isnan(value)) {
selected_sys_list->emplace_back(SystemWeight(sys, value));
selected_sys_list->emplace_back(SystemWeight(nsys, value));
}
} catch (const std::exception& e) {
CLS_ERROR("{}! {}", e.what(), sys->name());
@ -194,11 +196,11 @@ void OptimalSelectorBase::_calculate_parallel(const vector<std::pair<size_t, siz
auto selected_sys_list = std::make_shared<SystemWeightList>();
for (const auto& sys : m_pro_sys_list) {
try {
auto nsys = sys->clone();
nsys->run(q, true);
double value = evaluate(nsys, end_date);
sys->run(q, true);
double value = evaluate(sys, end_date);
sys->reset();
if (!std::isnan(value)) {
selected_sys_list->emplace_back(SystemWeight(nsys, value));
selected_sys_list->emplace_back(SystemWeight(sys->clone(), value));
}
} catch (const std::exception& e) {
CLS_ERROR("{}! {}", e.what(), sys->name());

View File

@ -140,6 +140,9 @@ void PerformanceOptimalSelector::_calculate_single(
}
if (selected_sys) {
selected_sys->reset();
selected_sys = selected_sys->clone();
size_t train_start = train_ranges[i].first;
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;
@ -180,7 +183,7 @@ void PerformanceOptimalSelector::_calculate_parallel(
Performance per;
SYSPtr selected_sys;
if (m_pro_sys_list.size() == 1) {
selected_sys = m_pro_sys_list.back();
selected_sys = m_pro_sys_list.back()->clone();
} else if (0 == mode) {
double max_value = std::numeric_limits<double>::lowest();
for (const auto& sys : m_pro_sys_list) {
@ -217,6 +220,8 @@ void PerformanceOptimalSelector::_calculate_parallel(
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
auto& selected_sys = sys_list[i];
if (selected_sys) {
selected_sys->reset();
size_t train_start = train_ranges[i].first;
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;

View File

@ -15,6 +15,7 @@ namespace hku {
SystemPtr HKU_API SYS_WalkForward(const SystemList& candidate_sys_list,
const TradeManagerPtr& tm = TradeManagerPtr(),
size_t train_len = 100, size_t test_len = 20,
// const SelectorPtr& se = SE_MaxFundsOptimal(),
const SelectorPtr& se = SE_PerformanceOptimal(),
const TradeManagerPtr& train_tm = TradeManagerPtr());

View File

@ -33,7 +33,7 @@ static SYSPtr create_test_sys(int fast_n, int slow_n) {
}
/** @par 检测点 */
TEST_CASE("test_SYS_WalkForword") {
TEST_CASE("test_SYS_WalkForword_SE_MaxFundsOptimal_not_parallel") {
Stock stk = getStock("sz000001");
KQuery query = KQueryByIndex(-50);
TMPtr tm = crtTM();
@ -49,11 +49,12 @@ TEST_CASE("test_SYS_WalkForword") {
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20);
CHECK_THROWS(sys->run(query));
/** @arg 只有一个候选系统 */
sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20);
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
auto se = SE_MaxFundsOptimal();
sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
CHECK_EQ(sys->name(), "SYS_WalkForward");
query = KQueryByIndex(-125);
// sys->setParam<bool>("trace", true);
sys->setParam<bool>("parallel", false);
sys->run(stk, query);
auto delay_request = sys->getBuyTradeRequest();
@ -71,7 +72,7 @@ TEST_CASE("test_SYS_WalkForword") {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
/** @arg 多个候选系统 */
/** @arg 多个候选系统 使用 SE_MaxFundsOptimal */
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
SystemList sys_list;
for (const auto& param : params) {
@ -79,9 +80,183 @@ TEST_CASE("test_SYS_WalkForword") {
}
tm->reset();
REQUIRE(tm->getTradeList().size() == 1);
sys = SYS_WalkForward(sys_list, tm, 30, 20);
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
query = KQueryByIndex(-125);
// sys->setParam<bool>("trace", true);
sys->setParam<bool>("parallel", false);
sys->run(stk, query);
delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 98286.0);
tr_list1 = tm->getTradeList();
tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
}
/** @par 检测点 */
TEST_CASE("test_SYS_WalkForword_SE_MaxFundsOptimal_parallel") {
Stock stk = getStock("sz000001");
KQuery query = KQueryByIndex(-50);
TMPtr tm = crtTM();
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
auto se = SE_MaxFundsOptimal();
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
CHECK_EQ(sys->name(), "SYS_WalkForward");
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", true);
sys->run(stk, query);
auto delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 99328.0);
auto tr_list1 = tm->getTradeList();
auto tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
SystemList sys_list;
for (const auto& param : params) {
sys_list.emplace_back(create_test_sys(param.first, param.second));
}
tm->reset();
REQUIRE(tm->getTradeList().size() == 1);
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", true);
sys->run(stk, query);
delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 98286.0);
tr_list1 = tm->getTradeList();
tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
}
/** @par 检测点 */
TEST_CASE("test_SYS_WalkForword_SE_PerformanceOptimal_not_parallel") {
Stock stk = getStock("sz000001");
KQuery query = KQueryByIndex(-50);
TMPtr tm = crtTM();
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
auto se = SE_PerformanceOptimal("当前总资产");
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
CHECK_EQ(sys->name(), "SYS_WalkForward");
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", false);
sys->run(stk, query);
auto delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 99328.0);
auto tr_list1 = tm->getTradeList();
auto tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
SystemList sys_list;
for (const auto& param : params) {
sys_list.emplace_back(create_test_sys(param.first, param.second));
}
tm->reset();
REQUIRE(tm->getTradeList().size() == 1);
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", false);
sys->run(stk, query);
delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 98286.0);
tr_list1 = tm->getTradeList();
tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
}
/** @par 检测点 */
TEST_CASE("test_SYS_WalkForword_SE_PerformanceOptimal_parallel") {
Stock stk = getStock("sz000001");
KQuery query = KQueryByIndex(-50);
TMPtr tm = crtTM();
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
auto se = SE_PerformanceOptimal("当前总资产");
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
CHECK_EQ(sys->name(), "SYS_WalkForward");
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", true);
sys->run(stk, query);
auto delay_request = sys->getBuyTradeRequest();
CHECK_UNARY(delay_request.valid);
CHECK_EQ(delay_request.business, BUSINESS_BUY);
CHECK_EQ(delay_request.datetime, Datetime(20111205));
tm = sys->getTM();
CHECK_EQ(tm->currentCash(), 99328.0);
auto tr_list1 = tm->getTradeList();
auto tr_list2 = sys->getTradeRecordList();
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
}
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
SystemList sys_list;
for (const auto& param : params) {
sys_list.emplace_back(create_test_sys(param.first, param.second));
}
tm->reset();
REQUIRE(tm->getTradeList().size() == 1);
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
query = KQueryByIndex(-125);
sys->setParam<bool>("parallel", true);
sys->run(stk, query);
delay_request = sys->getBuyTradeRequest();