调整 OptimalSelector 代码

This commit is contained in:
fasiondog 2024-09-16 21:36:20 +08:00
parent 4a37c670e2
commit 352badc28c
3 changed files with 71 additions and 58 deletions

View File

@ -77,9 +77,6 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
bool trace = getParam<bool>("trace");
CLS_INFO_IF(trace, "candidate sys list size: {}", m_pro_sys_list.size());
// for (const auto& sys : m_pro_sys_list) {
// HKU_DEBUG_IF(trace, "[SE_Optimal] {}", sys->name());
// }
size_t train_len = static_cast<size_t>(getParam<int>("train_len"));
size_t test_len = static_cast<size_t>(getParam<int>("test_len"));
@ -102,18 +99,27 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
string key = getParam<string>("key");
int mode = getParam<int>("mode");
CLS_INFO_IF(trace, "statistic key: {}, mode: {}", key, mode);
CLS_INFO_IF(trace, "statistic key: {}, mode: {}", getParam<string>("key"),
getParam<int>("mode"));
if (getParam<bool>("parallel")) {
_calculate_parallel(train_ranges, dates);
return;
_calculate_parallel(train_ranges, dates, key, mode, test_len, trace);
} else {
_calculate_single(train_ranges, dates, key, mode, test_len, trace);
}
m_calculated = true;
}
void OptimalSelector::_calculate_single(const vector<std::pair<size_t, size_t>>& train_ranges,
const DatetimeList& dates, const string& key, int mode,
size_t test_len, bool trace) {
size_t dates_len = dates.size();
Performance per;
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
Datetime start_date = dates[train_ranges[i].first];
Datetime end_date = dates[train_ranges[i].second];
KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType());
KQuery q = KQueryByDate(start_date, end_date, m_query.kType(), m_query.recoverType());
CLS_INFO_IF(trace, "iteration: {}|{}, range: {}", i + 1, total, q);
SYSPtr selected_sys;
if (m_pro_sys_list.size() == 1) {
@ -144,37 +150,38 @@ void OptimalSelector::calculate(const SystemList& pf_realSysList, const KQuery&
}
}
HKU_ASSERT(selected_sys);
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;
if (test_end > dates_len) {
test_end = dates_len;
}
selected_sys->reset();
selected_sys = selected_sys->clone();
for (size_t pos = test_start; pos < test_end; pos++) {
m_sys_dict[dates[pos]] = selected_sys;
}
if (test_end < dates_len) {
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
} else {
m_run_ranges.emplace_back(
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
}
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
}
if (selected_sys) {
selected_sys->reset();
selected_sys = selected_sys->clone();
m_calculated = true;
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;
if (test_end > dates_len) {
test_end = dates_len;
}
for (size_t pos = test_start; pos < test_end; pos++) {
m_sys_dict[dates[pos]] = selected_sys;
}
if (test_end < dates_len) {
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
} else {
m_run_ranges.emplace_back(
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
}
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
}
}
}
void OptimalSelector::_calculate_parallel(const vector<std::pair<size_t, size_t>>& train_ranges,
const DatetimeList& dates) {
const DatetimeList& dates, const string& key, int mode,
size_t test_len, bool trace) {
auto sys_list = parallel_for_index(
0, train_ranges.size(), [this, &train_ranges, &dates, query = m_query](size_t i) {
bool trace = getParam<bool>("trace");
string key = getParam<string>("key");
int mode = getParam<int>("mode");
0, train_ranges.size(),
[this, &train_ranges, &dates, query = m_query, trace, key, mode](size_t i) {
Datetime start_date = dates[train_ranges[i].first];
Datetime end_date = dates[train_ranges[i].second];
KQuery q = KQueryByDate(start_date, end_date, query.kType(), query.recoverType());
@ -210,35 +217,36 @@ void OptimalSelector::_calculate_parallel(const vector<std::pair<size_t, size_t>
}
}
HKU_IF_RETURN(!selected_sys, selected_sys);
selected_sys->reset();
return selected_sys->clone();
return selected_sys;
});
// size_t train_len = static_cast<size_t>(getParam<int>("train_len"));
bool trace = getParam<bool>("trace");
size_t dates_len = dates.size();
size_t test_len = static_cast<size_t>(getParam<int>("test_len"));
for (size_t i = 0, total = train_ranges.size(); i < total; i++) {
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;
if (test_end > dates_len) {
test_end = dates_len;
auto& selected_sys = sys_list[i];
if (selected_sys) {
selected_sys->reset();
selected_sys = selected_sys->clone();
size_t test_start = train_ranges[i].second;
size_t test_end = test_start + test_len;
if (test_end > dates_len) {
test_end = dates_len;
}
for (size_t pos = test_start; pos < test_end; pos++) {
m_sys_dict[dates[pos]] = selected_sys;
}
if (test_end < dates_len) {
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
} else {
m_run_ranges.emplace_back(
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
}
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
}
const auto& selected_sys = sys_list[i];
for (size_t pos = test_start; pos < test_end; pos++) {
m_sys_dict[dates[pos]] = selected_sys;
}
if (test_end < dates_len) {
m_run_ranges.emplace_back(std::make_pair(dates[test_start], dates[test_end]));
} else {
m_run_ranges.emplace_back(
std::make_pair(dates[test_start], dates[test_end - 1] + Seconds(1)));
}
CLS_INFO_IF(trace, "iteration: {}, selected_sys: {}", i + 1, selected_sys->name());
}
m_calculated = true;
}
SEPtr HKU_API SE_Optimal() {

View File

@ -30,8 +30,13 @@ public:
}
private:
void _calculate_single(const vector<std::pair<size_t, size_t>>& train_ranges,
const DatetimeList& dates, const string& key, int mode, size_t test_len,
bool trace);
void _calculate_parallel(const vector<std::pair<size_t, size_t>>& train_ranges,
const DatetimeList& dates);
const DatetimeList& dates, const string& key, int mode,
size_t test_len, bool trace);
private:
unordered_map<Datetime, SYSPtr> m_sys_dict;

View File

@ -123,7 +123,7 @@ TEST_CASE("test_SE_Optimal") {
query = KQueryByIndex(-125);
// se->setParam<bool>("trace", true);
se->setParam<bool>("parallel", true);
// se->setParam<bool>("parallel", true);
se->setParam<int>("train_len", 30);
se->setParam<int>("test_len", 20);
se->calculate(SystemList(), query);