mirror of
https://gitee.com/fasiondog/hikyuu.git
synced 2024-12-02 03:48:19 +08:00
调整 OptimalSelector 代码
This commit is contained in:
parent
4a37c670e2
commit
352badc28c
@ -77,9 +77,6 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
|
||||
|
||||
bool trace = getParam<bool>("trace");
|
||||
CLS_INFO_IF(trace, "candidate sys list size: {}", m_pro_sys_list.size());
|
||||
// for (const auto& sys : m_pro_sys_list) {
|
||||
// HKU_DEBUG_IF(trace, "[SE_Optimal] {}", sys->name());
|
||||
// }
|
||||
|
||||
size_t train_len = static_cast<size_t>(getParam<int>("train_len"));
|
||||
size_t test_len = static_cast<size_t>(getParam<int>("test_len"));
|
||||
@ -102,18 +99,27 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
|
||||
|
||||
string key = getParam<string>("key");
|
||||
int mode = getParam<int>("mode");
|
||||
CLS_INFO_IF(trace, "statistic key: {}, mode: {}", key, mode);
|
||||
CLS_INFO_IF(trace, "statistic key: {}, mode: {}", getParam<string>("key"),
|
||||
getParam<int>("mode"));
|
||||
|
||||
if (getParam<bool>("parallel")) {
|
||||
_calculate_parallel(train_ranges, dates);
|
||||
return;
|
||||
_calculate_parallel(train_ranges, dates, key, mode, test_len, trace);
|
||||
} else {
|
||||
_calculate_single(train_ranges, dates, key, mode, test_len, trace);
|
||||
}
|
||||
|
||||
m_calculated = true;
|
||||
}
|
||||
|
||||
void OptimalSelector::_calculate_single(const vector<std::pair<size_t, size_t>>& train_ranges,
|
||||
const DatetimeList& dates, const string& key, int mode,
|
||||
size_t test_len, bool trace) {
|
||||
size_t dates_len = dates.size();
|
||||
Performance per;
|
||||
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
|
||||
Datetime start_date = dates[train_ranges[i].first];
|
||||
Datetime end_date = dates[train_ranges[i].second];
|
||||
KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType());
|
||||
KQuery q = KQueryByDate(start_date, end_date, m_query.kType(), m_query.recoverType());
|
||||
CLS_INFO_IF(trace, "iteration: {}|{}, range: {}", i + 1, total, q);
|
||||
SYSPtr selected_sys;
|
||||
if (m_pro_sys_list.size() == 1) {
|
||||
@ -144,37 +150,38 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
|
||||
}
|
||||
}
|
||||
|
||||
HKU_ASSERT(selected_sys);
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
if (test_end > dates_len) {
|
||||
test_end = dates_len;
|
||||
}
|
||||
selected_sys->reset();
|
||||
selected_sys = selected_sys->clone();
|
||||
for (size_t pos = test_start; pos < test_end; pos++) {
|
||||
m_sys_dict[dates[pos]] = selected_sys;
|
||||
}
|
||||
if (test_end < dates_len) {
|
||||
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
|
||||
} else {
|
||||
m_run_ranges.emplace_back(
|
||||
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
|
||||
}
|
||||
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
|
||||
}
|
||||
if (selected_sys) {
|
||||
selected_sys->reset();
|
||||
selected_sys = selected_sys->clone();
|
||||
|
||||
m_calculated = true;
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
if (test_end > dates_len) {
|
||||
test_end = dates_len;
|
||||
}
|
||||
|
||||
for (size_t pos = test_start; pos < test_end; pos++) {
|
||||
m_sys_dict[dates[pos]] = selected_sys;
|
||||
}
|
||||
|
||||
if (test_end < dates_len) {
|
||||
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
|
||||
} else {
|
||||
m_run_ranges.emplace_back(
|
||||
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
|
||||
}
|
||||
|
||||
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OptimalSelector::_calculate_parallel(const vector<std::pair<size_t, size_t>>& train_ranges,
|
||||
const DatetimeList& dates) {
|
||||
const DatetimeList& dates, const string& key, int mode,
|
||||
size_t test_len, bool trace) {
|
||||
auto sys_list = parallel_for_index(
|
||||
0, train_ranges.size(), [this, &train_ranges, &dates, query = m_query](size_t i) {
|
||||
bool trace = getParam<bool>("trace");
|
||||
string key = getParam<string>("key");
|
||||
int mode = getParam<int>("mode");
|
||||
|
||||
0, train_ranges.size(),
|
||||
[this, &train_ranges, &dates, query = m_query, trace, key, mode](size_t i) {
|
||||
Datetime start_date = dates[train_ranges[i].first];
|
||||
Datetime end_date = dates[train_ranges[i].second];
|
||||
KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType());
|
||||
@ -210,35 +217,36 @@ void OptimalSelector::_calculate_parallel(const vector<std::pair<size_t, size_t>
|
||||
}
|
||||
}
|
||||
|
||||
HKU_IF_RETURN(!selected_sys, selected_sys);
|
||||
|
||||
selected_sys->reset();
|
||||
return selected_sys->clone();
|
||||
return selected_sys;
|
||||
});
|
||||
|
||||
// size_t train_len = static_cast<size_t>(getParam<int>("train_len"));
|
||||
bool trace = getParam<bool>("trace");
|
||||
size_t dates_len = dates.size();
|
||||
size_t test_len = static_cast<size_t>(getParam<int>("test_len"));
|
||||
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
if (test_end > dates_len) {
|
||||
test_end = dates_len;
|
||||
auto& selected_sys = sys_list[i];
|
||||
if (selected_sys) {
|
||||
selected_sys->reset();
|
||||
selected_sys = selected_sys->clone();
|
||||
|
||||
size_t test_start = train_ranges[i].second;
|
||||
size_t test_end = test_start + test_len;
|
||||
if (test_end > dates_len) {
|
||||
test_end = dates_len;
|
||||
}
|
||||
|
||||
for (size_t pos = test_start; pos < test_end; pos++) {
|
||||
m_sys_dict[dates[pos]] = selected_sys;
|
||||
}
|
||||
|
||||
if (test_end < dates_len) {
|
||||
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
|
||||
} else {
|
||||
m_run_ranges.emplace_back(
|
||||
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
|
||||
}
|
||||
|
||||
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
|
||||
}
|
||||
const auto& selected_sys = sys_list[i];
|
||||
for (size_t pos = test_start; pos < test_end; pos++) {
|
||||
m_sys_dict[dates[pos]] = selected_sys;
|
||||
}
|
||||
if (test_end < dates_len) {
|
||||
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
|
||||
} else {
|
||||
m_run_ranges.emplace_back(
|
||||
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
|
||||
}
|
||||
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
|
||||
}
|
||||
m_calculated = true;
|
||||
}
|
||||
|
||||
SEPtr HKU_API SE_Optimal() {
|
||||
|
@ -30,8 +30,13 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void _calculate_single(const vector<std::pair<size_t, size_t>>& train_ranges,
|
||||
const DatetimeList& dates, const string& key, int mode, size_t test_len,
|
||||
bool trace);
|
||||
|
||||
void _calculate_parallel(const vector<std::pair<size_t, size_t>>& train_ranges,
|
||||
const DatetimeList& dates);
|
||||
const DatetimeList& dates, const string& key, int mode,
|
||||
size_t test_len, bool trace);
|
||||
|
||||
private:
|
||||
unordered_map<Datetime, SYSPtr> m_sys_dict;
|
||||
|
@ -123,7 +123,7 @@ TEST_CASE("test_SE_Optimal") {
|
||||
|
||||
query = KQueryByIndex(-125);
|
||||
// se->setParam<bool>("trace", true);
|
||||
se->setParam<bool>("parallel", true);
|
||||
// se->setParam<bool>("parallel", true);
|
||||
se->setParam<int>("train_len", 30);
|
||||
se->setParam<int>("test_len", 20);
|
||||
se->calculate(SystemList(), query);
|
||||
|
Loading…
Reference in New Issue
Block a user