mirror of
https://gitee.com/fasiondog/hikyuu.git
synced 2024-11-29 18:39:10 +08:00
补充 SYS_WalkForword 测试及相关修正
This commit is contained in:
parent
cafabae9ce
commit
556a206193
@ -139,10 +139,12 @@ void OptimalSelectorBase::_calculate_single(const vector<std::pair<size_t, size_
|
||||
selected_sys_list = std::make_shared<SystemWeightList>();
|
||||
for (const auto& sys : m_pro_sys_list) {
|
||||
try {
|
||||
sys->run(q, true);
|
||||
double value = evaluate(sys, end_date);
|
||||
auto nsys = sys->clone();
|
||||
nsys->run(q, true);
|
||||
double value = evaluate(nsys, end_date);
|
||||
nsys->reset();
|
||||
if (!std::isnan(value)) {
|
||||
selected_sys_list->emplace_back(SystemWeight(sys, value));
|
||||
selected_sys_list->emplace_back(SystemWeight(nsys, value));
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
CLS_ERROR("{}! {}", e.what(), sys->name());
|
||||
@ -194,11 +196,11 @@ void OptimalSelectorBase::_calculate_parallel(const vector<std::pair<size_t, siz
|
||||
auto selected_sys_list = std::make_shared<SystemWeightList>();
|
||||
for (const auto& sys : m_pro_sys_list) {
|
||||
try {
|
||||
auto nsys = sys->clone();
|
||||
nsys->run(q, true);
|
||||
double value = evaluate(nsys, end_date);
|
||||
sys->run(q, true);
|
||||
double value = evaluate(sys, end_date);
|
||||
sys->reset();
|
||||
if (!std::isnan(value)) {
|
||||
selected_sys_list->emplace_back(SystemWeight(nsys, value));
|
||||
selected_sys_list->emplace_back(SystemWeight(sys->clone(), value));
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
CLS_ERROR("{}! {}", e.what(), sys->name());
|
||||
|
@ -140,6 +140,9 @@ void PerformanceOptimalSelector::_calculate_single(
|
||||
}
|
||||
|
||||
if (selected_sys) {
|
||||
selected_sys->reset();
|
||||
selected_sys = selected_sys->clone();
|
||||
|
||||
size_t train_start = train_ranges[i].first;
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
@ -180,7 +183,7 @@ void PerformanceOptimalSelector::_calculate_parallel(
|
||||
Performance per;
|
||||
SYSPtr selected_sys;
|
||||
if (m_pro_sys_list.size() == 1) {
|
||||
selected_sys = m_pro_sys_list.back();
|
||||
selected_sys = m_pro_sys_list.back()->clone();
|
||||
} else if (0 == mode) {
|
||||
double max_value = std::numeric_limits<double>::lowest();
|
||||
for (const auto& sys : m_pro_sys_list) {
|
||||
@ -217,6 +220,8 @@ void PerformanceOptimalSelector::_calculate_parallel(
|
||||
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
|
||||
auto& selected_sys = sys_list[i];
|
||||
if (selected_sys) {
|
||||
selected_sys->reset();
|
||||
|
||||
size_t train_start = train_ranges[i].first;
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
|
@ -15,6 +15,7 @@ namespace hku {
|
||||
SystemPtr HKU_API SYS_WalkForward(const SystemList& candidate_sys_list,
|
||||
const TradeManagerPtr& tm = TradeManagerPtr(),
|
||||
size_t train_len = 100, size_t test_len = 20,
|
||||
// const SelectorPtr& se = SE_MaxFundsOptimal(),
|
||||
const SelectorPtr& se = SE_PerformanceOptimal(),
|
||||
const TradeManagerPtr& train_tm = TradeManagerPtr());
|
||||
|
||||
|
@ -33,7 +33,7 @@ static SYSPtr create_test_sys(int fast_n, int slow_n) {
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_SYS_WalkForword") {
|
||||
TEST_CASE("test_SYS_WalkForword_SE_MaxFundsOptimal_not_parallel") {
|
||||
Stock stk = getStock("sz000001");
|
||||
KQuery query = KQueryByIndex(-50);
|
||||
TMPtr tm = crtTM();
|
||||
@ -49,11 +49,12 @@ TEST_CASE("test_SYS_WalkForword") {
|
||||
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20);
|
||||
CHECK_THROWS(sys->run(query));
|
||||
|
||||
/** @arg 只有一个候选系统 */
|
||||
sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20);
|
||||
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
auto se = SE_MaxFundsOptimal();
|
||||
sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
|
||||
CHECK_EQ(sys->name(), "SYS_WalkForward");
|
||||
query = KQueryByIndex(-125);
|
||||
// sys->setParam<bool>("trace", true);
|
||||
sys->setParam<bool>("parallel", false);
|
||||
sys->run(stk, query);
|
||||
|
||||
auto delay_request = sys->getBuyTradeRequest();
|
||||
@ -71,7 +72,7 @@ TEST_CASE("test_SYS_WalkForword") {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
|
||||
/** @arg 多个候选系统 */
|
||||
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
|
||||
SystemList sys_list;
|
||||
for (const auto& param : params) {
|
||||
@ -79,9 +80,183 @@ TEST_CASE("test_SYS_WalkForword") {
|
||||
}
|
||||
tm->reset();
|
||||
REQUIRE(tm->getTradeList().size() == 1);
|
||||
sys = SYS_WalkForward(sys_list, tm, 30, 20);
|
||||
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
|
||||
query = KQueryByIndex(-125);
|
||||
// sys->setParam<bool>("trace", true);
|
||||
sys->setParam<bool>("parallel", false);
|
||||
sys->run(stk, query);
|
||||
|
||||
delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 98286.0);
|
||||
|
||||
tr_list1 = tm->getTradeList();
|
||||
tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_SYS_WalkForword_SE_MaxFundsOptimal_parallel") {
|
||||
Stock stk = getStock("sz000001");
|
||||
KQuery query = KQueryByIndex(-50);
|
||||
TMPtr tm = crtTM();
|
||||
|
||||
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
auto se = SE_MaxFundsOptimal();
|
||||
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
|
||||
CHECK_EQ(sys->name(), "SYS_WalkForward");
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", true);
|
||||
sys->run(stk, query);
|
||||
|
||||
auto delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 99328.0);
|
||||
|
||||
auto tr_list1 = tm->getTradeList();
|
||||
auto tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
|
||||
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
|
||||
SystemList sys_list;
|
||||
for (const auto& param : params) {
|
||||
sys_list.emplace_back(create_test_sys(param.first, param.second));
|
||||
}
|
||||
tm->reset();
|
||||
REQUIRE(tm->getTradeList().size() == 1);
|
||||
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", true);
|
||||
sys->run(stk, query);
|
||||
|
||||
delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 98286.0);
|
||||
|
||||
tr_list1 = tm->getTradeList();
|
||||
tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_SYS_WalkForword_SE_PerformanceOptimal_not_parallel") {
|
||||
Stock stk = getStock("sz000001");
|
||||
KQuery query = KQueryByIndex(-50);
|
||||
TMPtr tm = crtTM();
|
||||
|
||||
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
auto se = SE_PerformanceOptimal("当前总资产");
|
||||
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
|
||||
CHECK_EQ(sys->name(), "SYS_WalkForward");
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", false);
|
||||
sys->run(stk, query);
|
||||
|
||||
auto delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 99328.0);
|
||||
|
||||
auto tr_list1 = tm->getTradeList();
|
||||
auto tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
|
||||
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
|
||||
SystemList sys_list;
|
||||
for (const auto& param : params) {
|
||||
sys_list.emplace_back(create_test_sys(param.first, param.second));
|
||||
}
|
||||
tm->reset();
|
||||
REQUIRE(tm->getTradeList().size() == 1);
|
||||
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", false);
|
||||
sys->run(stk, query);
|
||||
|
||||
delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 98286.0);
|
||||
|
||||
tr_list1 = tm->getTradeList();
|
||||
tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/** @par 检测点 */
|
||||
TEST_CASE("test_SYS_WalkForword_SE_PerformanceOptimal_parallel") {
|
||||
Stock stk = getStock("sz000001");
|
||||
KQuery query = KQueryByIndex(-50);
|
||||
TMPtr tm = crtTM();
|
||||
|
||||
/** @arg 只有一个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
auto se = SE_PerformanceOptimal("当前总资产");
|
||||
auto sys = SYS_WalkForward(SystemList{create_test_sys(3, 5)}, tm, 30, 20, se);
|
||||
CHECK_EQ(sys->name(), "SYS_WalkForward");
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", true);
|
||||
sys->run(stk, query);
|
||||
|
||||
auto delay_request = sys->getBuyTradeRequest();
|
||||
CHECK_UNARY(delay_request.valid);
|
||||
CHECK_EQ(delay_request.business, BUSINESS_BUY);
|
||||
CHECK_EQ(delay_request.datetime, Datetime(20111205));
|
||||
|
||||
tm = sys->getTM();
|
||||
CHECK_EQ(tm->currentCash(), 99328.0);
|
||||
|
||||
auto tr_list1 = tm->getTradeList();
|
||||
auto tr_list2 = sys->getTradeRecordList();
|
||||
CHECK_EQ(tr_list1.size(), tr_list2.size() + 1);
|
||||
for (size_t i = 0, total = tr_list2.size(); i < total; i++) {
|
||||
CHECK_EQ(tr_list1[i + 1], tr_list2[i]);
|
||||
}
|
||||
|
||||
/** @arg 多个候选系统, 使用 SE_MaxFundsOptimal */
|
||||
vector<std::pair<int, int>> params{{3, 5}, {3, 10}, {5, 10}, {5, 20}};
|
||||
SystemList sys_list;
|
||||
for (const auto& param : params) {
|
||||
sys_list.emplace_back(create_test_sys(param.first, param.second));
|
||||
}
|
||||
tm->reset();
|
||||
REQUIRE(tm->getTradeList().size() == 1);
|
||||
sys = SYS_WalkForward(sys_list, tm, 30, 20, se);
|
||||
query = KQueryByIndex(-125);
|
||||
sys->setParam<bool>("parallel", true);
|
||||
sys->run(stk, query);
|
||||
|
||||
delay_request = sys->getBuyTradeRequest();
|
||||
|
Loading…
Reference in New Issue
Block a user