优化 MultiFactorSelector, 补充测试

This commit is contained in:
fasiondog 2024-06-09 17:07:18 +08:00
parent 04e8b3a846
commit 917171c3a0
7 changed files with 130 additions and 79 deletions

View File

@ -24,7 +24,7 @@ SelectorPtr HKU_API SE_MultiFactor(const MFPtr& mf, int topn = 10);
/**
* MultiFactor
* @param src_inds
* @param topn topn
* @param topn topn 0
* @param ic_n ic ic_n
* @param ic_rolling_n IC IC n
* @param ref_stk 使 sh000300 300

View File

@ -17,6 +17,7 @@ namespace hku {
MultiFactorSelector::MultiFactorSelector() : SelectorBase("SE_MultiFactor") {
// 只选择发出买入信号的系统,此时选中的系统会变成资产平均分配,参考 AF 参数ignore_zero_weight
setParam<bool>("only_should_buy", false);
setParam<bool>("ignore_null", true); // 是否忽略 MF 中 score 值为 nan 的证券
setParam<int>("topn", 10);
setParam<int>("ic_n", 5);
setParam<int>("ic_rolling_n", 120);
@ -29,6 +30,7 @@ MultiFactorSelector::MultiFactorSelector(const MFPtr& mf, int topn)
: SelectorBase("SE_MultiFactor"), m_mf(mf) {
HKU_CHECK(mf, "mf is null!");
setParam<bool>("only_should_buy", false);
setParam<bool>("ignore_null", true);
setParam<int>("topn", topn);
setParam<int>("ic_n", mf->getParam<int>("ic_n"));
@ -43,10 +45,7 @@ MultiFactorSelector::MultiFactorSelector(const MFPtr& mf, int topn)
MultiFactorSelector::~MultiFactorSelector() {}
void MultiFactorSelector::_checkParam(const string& name) const {
if ("topn" == name) {
int topn = getParam<int>("topn");
HKU_ASSERT(topn > 0);
} else if ("ic_n" == name) {
if ("ic_n" == name) {
HKU_ASSERT(getParam<int>("ic_n") >= 1);
} else if ("ic_rolling_n" == name) {
HKU_ASSERT(getParam<int>("ic_rolling_n") >= 1);
@ -79,8 +78,19 @@ bool MultiFactorSelector::isMatchAF(const AFPtr& af) {
SystemWeightList MultiFactorSelector::getSelected(Datetime date) {
SystemWeightList ret;
auto scores = m_mf->getScores(date, 0, getParam<int>("topn"),
[](const ScoreRecord& sc) { return !std::isnan(sc.value); });
int topn = getParam<int>("topn");
if (topn <= 0) {
topn = std::numeric_limits<int>::max();
}
ScoreRecordList scores;
if (getParam<bool>("ignore_null")) {
scores = m_mf->getScores(date, 0, getParam<int>("topn"),
[](const ScoreRecord& sc) { return !std::isnan(sc.value); });
} else {
scores = m_mf->getScores(date, 0, getParam<int>("topn"));
}
if (getParam<bool>("only_should_buy")) {
for (const auto& sc : scores) {
auto sys = m_stk_sys_dict[sc.stock];
@ -124,10 +134,10 @@ void MultiFactorSelector::_calculate() {
HKU_THROW("Invalid mode: {}", mode);
}
} else {
m_mf->setQuery(query);
m_mf->setRefIndicators(m_inds);
m_mf->setRefStock(ref_stk);
m_mf->setStockList(stks);
m_mf->setQuery(query);
m_mf->setParam<int>("ic_n", ic_n);
if (m_mf->haveParam("ic_rolling_n")) {
m_mf->setParam<int>("ic_rolling_n", ic_rolling_n);
@ -151,7 +161,7 @@ SelectorPtr HKU_API SE_MultiFactor(const IndicatorList& src_inds, int topn = 10,
p->setParam<int>("topn", topn);
p->setParam<int>("ic_n", ic_n);
p->setParam<int>("ic_rolling_n", ic_rolling_n);
p->setParam<Stock>("ref_stock", ref_stk);
p->setParam<Stock>("ref_stk", ref_stk);
p->setParam<string>("mode", mode);
return p;
}

View File

@ -26,6 +26,7 @@ public:
virtual void _calculate() override;
void setIndicators(const IndicatorList& inds) {
HKU_ASSERT(!inds.empty());
m_inds = inds;
}

View File

@ -7,6 +7,7 @@
#include "doctest/doctest.h"
#include <hikyuu/StockManager.h>
#include <hikyuu/trade_manage/crt/crtTM.h>
#include <hikyuu/trade_sys/system/crt/SYS_Simple.h>
#include <hikyuu/trade_sys/selector/crt/SE_Fixed.h>
#include <hikyuu/trade_sys/signal/crt/SG_Cross.h>
@ -17,7 +18,7 @@
using namespace hku;
/**
* @defgroup test_Selector test_Selector
* @defgroup test_SE_Fixed test_SE_Fixed
* @ingroup test_hikyuu_trade_sys_suite
* @{
*/
@ -80,4 +81,44 @@ TEST_CASE("test_SE_Fixed") {
CHECK_EQ(sm["sz000002"], result[2].sys->getStock());
}
//-----------------------------------------------------------------------------
// test export
//-----------------------------------------------------------------------------
#if HKU_SUPPORT_SERIALIZATION
/** @par 检测点 */
TEST_CASE("test_SE_Fixed_export") {
StockManager& sm = StockManager::instance();
string filename(sm.tmpdir());
filename += "/SE_FIXED.xml";
TMPtr tm = crtTM(Datetime(20010101), 100000);
SGPtr sg = SG_Cross(MA(CLOSE(), 5), MA(CLOSE(), 10));
MMPtr mm = MM_FixedCount(100);
SYSPtr sys = SYS_Simple();
sys->setTM(tm);
sys->setSG(sg);
sys->setMM(mm);
StockList stkList;
stkList.push_back(sm["sh600000"]);
stkList.push_back(sm["sz000001"]);
SEPtr se1 = SE_Fixed(stkList, sys);
{
std::ofstream ofs(filename);
boost::archive::xml_oarchive oa(ofs);
oa << BOOST_SERIALIZATION_NVP(se1);
}
SEPtr se2;
{
std::ifstream ifs(filename);
boost::archive::xml_iarchive ia(ifs);
ia >> BOOST_SERIALIZATION_NVP(se2);
}
CHECK_EQ(se1->name(), se2->name());
}
#endif /* #if HKU_SUPPORT_SERIALIZATION */
/** @} */

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 hikyuu.org
*
* Created on: 2024-06-09
* Author: fasiondog
*/
#include "doctest/doctest.h"
#include <hikyuu/StockManager.h>
#include <hikyuu/trade_sys/system/crt/SYS_Simple.h>
#include <hikyuu/trade_sys/selector/crt/SE_MultiFactor.h>
#include <hikyuu/trade_sys/selector/imp/MultiFactorSelector.h>
#include <hikyuu/trade_sys/signal/crt/SG_Cycle.h>
#include <hikyuu/trade_sys/moneymanager/crt/MM_Nothing.h>
#include <hikyuu/trade_manage/crt/crtTM.h>
#include <hikyuu/indicator/crt/KDATA.h>
#include <hikyuu/indicator/crt/MA.h>
#include <hikyuu/indicator/crt/AMA.h>
#include <hikyuu/indicator/crt/EMA.h>
using namespace hku;
/**
* @defgroup test_SE_MultiFactor test_SE_MultiFactor
* @ingroup test_hikyuu_trade_sys_suite
* @{
*/
/** @par 检测点 */
TEST_CASE("test_SE_MultiFactor") {
StockManager& sm = StockManager::instance();
StockList stks{sm["sh600004"], sm["sh600005"], sm["sz000001"], sm["sz000002"]};
Stock ref_stk = sm["sh000001"];
KQuery query = KQuery(-100);
IndicatorList src_inds{MA(CLOSE()), EMA(CLOSE())};
auto sys = SYS_Simple(crtTM(), MM_Nothing());
sys->setSG(SG_Cycle());
sys->setParam<bool>("buy_delay", false);
/** @arg 测试试图修改参数值为非法值 */
auto ret = SE_MultiFactor(src_inds);
CHECK_THROWS(ret->setParam<int>("ic_n", 0));
CHECK_THROWS(ret->setParam<int>("ic_rolling_n", 0));
CHECK_THROWS(ret->setParam<string>("mode", "MF"));
/** @arg src_inds 为空,其余为默认参数 */
CHECK_THROWS(SE_MultiFactor(IndicatorList{}));
/** @arg 默认参数 */
ret = SE_MultiFactor(src_inds, 10, 5, 120, ref_stk);
ret->addStockList(stks, sys);
auto proto_list = ret->getProtoSystemList();
ret->calculate(proto_list, query);
}
//-----------------------------------------------------------------------------
// test export
//-----------------------------------------------------------------------------
#if HKU_SUPPORT_SERIALIZATION
/** @par 检测点 */
TEST_CASE("test_SE_MultiFactor_export") {}
#endif /* #if HKU_SUPPORT_SERIALIZATION */
/** @} */

View File

@ -1,68 +0,0 @@
/*
* test_export.cpp
*
* Created on: 2018-2-10
* Author: fasiondog
*/
#include "doctest/doctest.h"
#include <hikyuu/config.h>
#if HKU_SUPPORT_SERIALIZATION
#include <fstream>
#include <boost/archive/xml_oarchive.hpp>
#include <boost/archive/xml_iarchive.hpp>
#include <hikyuu/StockManager.h>
#include <hikyuu/trade_manage/crt/crtTM.h>
#include <hikyuu/trade_sys/signal/crt/SG_Cross.h>
#include <hikyuu/trade_sys/moneymanager/crt/MM_FixedCount.h>
#include <hikyuu/indicator/crt/KDATA.h>
#include <hikyuu/indicator/crt/MA.h>
#include <hikyuu/trade_sys/selector/crt/SE_Fixed.h>
#include <hikyuu/trade_sys/system/crt/SYS_Simple.h>
using namespace hku;
/**
* @defgroup test_selector_serialization test_selector_serialization
* @ingroup test_hikyuu_trade_sys_suite
* @{
*/
/** @par 检测点 */
TEST_CASE("test_SE_FIXED_export") {
StockManager& sm = StockManager::instance();
string filename(sm.tmpdir());
filename += "/SE_FIXED.xml";
TMPtr tm = crtTM(Datetime(20010101), 100000);
SGPtr sg = SG_Cross(MA(CLOSE(), 5), MA(CLOSE(), 10));
MMPtr mm = MM_FixedCount(100);
SYSPtr sys = SYS_Simple();
sys->setTM(tm);
sys->setSG(sg);
sys->setMM(mm);
StockList stkList;
stkList.push_back(sm["sh600000"]);
stkList.push_back(sm["sz000001"]);
SEPtr se1 = SE_Fixed(stkList, sys);
{
std::ofstream ofs(filename);
boost::archive::xml_oarchive oa(ofs);
oa << BOOST_SERIALIZATION_NVP(se1);
}
SEPtr se2;
{
std::ifstream ifs(filename);
boost::archive::xml_iarchive ia(ifs);
ia >> BOOST_SERIALIZATION_NVP(se2);
}
CHECK_EQ(se1->name(), se2->name());
}
/** @} */
#endif /* HKU_SUPPORT_SERIALIZATION */

View File

@ -187,8 +187,9 @@ void export_Selector(py::module& m) {
- :
:param sequense(Indicator) inds:
:param Stock ref_stk: ( sh000300 300)
:param int topn: topn 0
:param int ic_n: IC N
:param int ic_rolling_n: IC
:param Stock ref_stk: ( sh000300 300)
:param str mode: "MF_ICIRWeight" | "MF_ICWeight" | "MF_EqualWeight" )");
}