调整 CORR 实现

This commit is contained in:
fasiondog 2024-03-10 03:19:37 +08:00
parent 58178ec575
commit a9b64b973d
14 changed files with 208 additions and 135 deletions

View File

@ -207,7 +207,7 @@
:param Indicator ind1: 指标1
:param Indicator ind2: 指标2
:param int n: 按指定 n 的长度计算两个 ind 直接数据相关系数
:param int n: 按指定 n 的长度计算两个 ind 直接数据相关系数。如果为0使用输入的ind长度。
:rtype: Indicator

View File

@ -305,14 +305,4 @@ Indicator HKU_API IF(const Indicator& x, Indicator::value_t a, Indicator::value_
return IF(x, CVAL(x, a), CVAL(x, b));
}
Indicator HKU_API CORR(const Indicator& ind1, const Indicator& ind2, int n) {
HKU_ERROR_IF_RETURN(!ind1.getImp() || !ind2.getImp(), Indicator(),
"ind1 or ind2 is Null Indicator!");
HKU_ERROR_IF_RETURN(n < 2, Indicator(), "Invalid param n: {} (need >= 2)", n);
IndicatorImpPtr p = make_shared<IndicatorImp>("CORR");
p->setParam<int>("n", n);
p->add(IndicatorImp::CORR, ind1.getImp(), ind2.getImp());
return p->calculate();
}
} /* namespace hku */

View File

@ -379,14 +379,6 @@ Indicator HKU_API IF(const Indicator& x, Indicator::value_t a, const Indicator&
Indicator HKU_API IF(const Indicator& x, const Indicator& a, Indicator::value_t b);
Indicator HKU_API IF(const Indicator& x, Indicator::value_t a, Indicator::value_t b);
/**
*
* @param ind1 1
* @param ind2 2
* @ingroup Indicator
*/
Indicator HKU_API CORR(const Indicator& ind1, const Indicator& ind2, int n);
} /* namespace hku */
#if FMT_VERSION >= 90000

View File

@ -93,10 +93,6 @@ string HKU_API getOPTypeName(IndicatorImp::OPType op) {
name = "IF";
break;
case IndicatorImp::CORR:
name = "CORR";
break;
default:
name = "UNKNOWN";
break;
@ -534,10 +530,6 @@ string IndicatorImp::formula() const {
<< m_right->formula() << ")";
break;
case CORR:
buf << m_name << "(" << m_left->formula() << ", " << m_right->formula() << ")";
break;
default:
HKU_ERROR("Wrong optype! {}", int(m_optype));
break;
@ -797,10 +789,6 @@ Indicator IndicatorImp::calculate() {
execute_if();
break;
case CORR:
execute_corr();
break;
default:
HKU_ERROR("Unkown Indicator::OPType! {}", int(m_optype));
break;
@ -1463,88 +1451,6 @@ void IndicatorImp::execute_if() {
}
}
void IndicatorImp::execute_corr() {
m_right->calculate();
m_left->calculate();
const IndicatorImp *maxp, *minp;
if (m_right->size() > m_left->size()) {
maxp = m_right.get();
minp = m_left.get();
} else {
maxp = m_left.get();
minp = m_right.get();
}
size_t total = maxp->size();
size_t discard = maxp->size() - minp->size() + minp->discard();
if (discard < maxp->discard()) {
discard = maxp->discard();
}
// 结果 0 存放相关系数结果
// 结果 1 存放协方差COV结果
_readyBuffer(total, 2);
int n = getParam<int>("n");
if (n < 2 || discard + 2 > total) {
setDiscard(total);
return;
}
size_t startPos = discard;
size_t first_end = startPos + n >= total ? total : startPos + n;
value_t kx = maxp->get(discard);
value_t ky = minp->get(discard);
value_t ex = 0.0, ey = 0.0, exy = 0.0, varx = 0.0, vary = 0.0, cov = 0.0;
value_t ex2 = 0.0, ey2 = 0.0;
value_t ix, iy;
auto *dst0 = this->data(0);
auto *dst1 = this->data(1);
auto const *maxdata = maxp->data(0);
auto const *mindata = minp->data(0);
for (size_t i = startPos + 1; i < first_end; i++) {
ix = maxdata[i] - kx;
iy = mindata[i] - ky;
ex += ix;
ey += iy;
value_t powx2 = ix * ix;
value_t powy2 = iy * iy;
value_t powxy = ix * iy;
exy += powxy;
ex2 += powx2;
ey2 += powy2;
size_t nobs = i - startPos;
varx = ex2 - powx2 / nobs;
vary = ey2 - powy2 / nobs;
cov = exy - powxy / nobs;
dst0[i] = cov / std::sqrt(varx * vary);
dst1[i] = cov / (nobs - 1);
}
for (size_t i = first_end; i < total; i++) {
ix = maxdata[i] - kx;
iy = mindata[i] - ky;
ex += maxdata[i] - maxdata[i - n];
ey += mindata[i] - mindata[i - n];
value_t preix = maxdata[i - n] - kx;
value_t preiy = mindata[i - n] - ky;
ex2 += ix * ix - preix * preix;
ey2 += iy * iy - preiy * preiy;
exy += ix * iy - preix * preiy;
varx = (ex2 - ex * ex / n);
vary = (ey2 - ey * ey / n);
cov = (exy - ex * ey / n);
dst0[i] = cov / std::sqrt(varx * vary);
dst1[i] = cov / (n - 1);
}
// 修正 discard
setDiscard(discard + 2);
}
void IndicatorImp::_dyn_calculate(const Indicator &ind) {
// SPEND_TIME(IndicatorImp__dyn_calculate);
const auto &ind_param = getIndParamImp("n");

View File

@ -48,7 +48,6 @@ public:
OR, ///< 或
WEAVE, ///< 特殊的,需要两个指标作为参数的指标
OP_IF, /// if操作
CORR, ///< 相关系数,需要两个指标作为参数
INVALID
};
@ -195,8 +194,6 @@ private:
void execute_or();
void execute_weave();
void execute_if();
void execute_corr();
void execute_spearman();
std::vector<IndicatorImpPtr> getAllSubNodes();
void repeatALikeNodes();

View File

@ -27,6 +27,7 @@
#include "crt/BARSSINCE.h"
#include "crt/BETWEEN.h"
#include "crt/CEILING.h"
#include "crt/CORR.h"
#include "crt/COS.h"
#include "crt/COST.h"
#include "crt/COUNT.h"

View File

@ -0,0 +1,23 @@
/*
* Copyright (c) 2024 hikyuu.org
*
* Created on: 2024-03-10
* Author: fasiondog
*/
#pragma once
#include "../Indicator.h"
namespace hku {
/**
*
* @param ind1 1
* @param ind2 2
* @ingroup Indicator
*/
Indicator HKU_API CORR(int n = 10);
Indicator HKU_API CORR(const Indicator& ind1, const Indicator& ind2, int n = 10);
} // namespace hku

View File

@ -5,6 +5,7 @@
* Author: fasiondog
*/
#pragma once
#include "../Indicator.h"
namespace hku {

View File

@ -0,0 +1,123 @@
/*
* Copyright (c) 2024 hikyuu.org
*
* Created on: 2024-03-10
* Author: fasiondog
*/
#include "hikyuu/indicator/crt/ALIGN.h"
#include "ICorr.h"
#if HKU_SUPPORT_SERIALIZATION
BOOST_CLASS_EXPORT(hku::ICorr)
#endif
namespace hku {
ICorr::ICorr() : IndicatorImp("CORR") {
setParam<int>("n", 10);
}
ICorr::ICorr(int n) : IndicatorImp("CORR") {
setParam<int>("n", n);
}
ICorr::ICorr(const Indicator& ref_ind, int n) : IndicatorImp("CORR"), m_ref_ind(ref_ind) {
setParam<int>("n", n);
}
ICorr::~ICorr() {}
bool ICorr::check() {
int n = getParam<int>("n");
return n == 0 || n >= 2;
}
IndicatorImpPtr ICorr::_clone() {
ICorr* p = new ICorr();
p->m_ref_ind = m_ref_ind.clone();
return IndicatorImpPtr(p);
}
void ICorr::_calculate(const Indicator& ind) {
auto k = getContext();
m_ref_ind.setContext(k);
Indicator ref = m_ref_ind;
if (m_ref_ind.size() != ind.size()) {
ref = ALIGN(m_ref_ind, ind);
}
size_t total = ind.size();
_readyBuffer(total, 2);
HKU_IF_RETURN(total == 0, void());
int n = getParam<int>("n");
if (n == 0) {
n = total;
}
m_discard = std::max(ind.discard(), ref.discard());
size_t startPos = m_discard;
size_t first_end = startPos + n >= total ? total : startPos + n;
auto const* datax = ind.data();
auto const* datay = ref.data();
value_t kx = datax[m_discard];
value_t ky = datay[m_discard];
value_t ex = 0.0, ey = 0.0, exy = 0.0, varx = 0.0, vary = 0.0, cov = 0.0;
value_t ex2 = 0.0, ey2 = 0.0;
value_t ix, iy;
auto* dst0 = this->data(0);
auto* dst1 = this->data(1);
for (size_t i = startPos + 1; i < first_end; i++) {
ix = datax[i] - kx;
iy = datay[i] - ky;
ex += ix;
ey += iy;
value_t powx2 = ix * ix;
value_t powy2 = iy * iy;
value_t powxy = ix * iy;
exy += powxy;
ex2 += powx2;
ey2 += powy2;
size_t nobs = i - startPos;
varx = ex2 - powx2 / nobs;
vary = ey2 - powy2 / nobs;
cov = exy - powxy / nobs;
dst0[i] = cov / std::sqrt(varx * vary);
dst1[i] = cov / (nobs - 1);
}
for (size_t i = first_end; i < total; i++) {
ix = datax[i] - kx;
iy = datay[i] - ky;
ex += datax[i] - datax[i - n];
ey += datay[i] - datay[i - n];
value_t preix = datax[i - n] - kx;
value_t preiy = datay[i - n] - ky;
ex2 += ix * ix - preix * preix;
ey2 += iy * iy - preiy * preiy;
exy += ix * iy - preix * preiy;
varx = (ex2 - ex * ex / n);
vary = (ey2 - ey * ey / n);
cov = (exy - ex * ey / n);
dst0[i] = cov / std::sqrt(varx * vary);
dst1[i] = cov / (n - 1);
}
// 修正 discard
m_discard = (m_discard + 2 < total) ? m_discard + 2 : total;
}
Indicator HKU_API CORR(int n) {
return make_shared<ICorr>(n);
}
Indicator HKU_API CORR(const Indicator& ind1, const Indicator& ind2, int n) {
auto p = make_shared<ICorr>(ind2, n);
Indicator result(p);
return result(ind1);
}
} // namespace hku

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 hikyuu.org
*
* Created on: 2024-03-10
* Author: fasiondog
*/
#pragma once
#include "../Indicator.h"
namespace hku {
class ICorr : public IndicatorImp {
public:
ICorr();
ICorr(int n);
ICorr(const Indicator& ref_ind, int n);
virtual ~ICorr();
virtual bool check() override;
virtual void _calculate(const Indicator& data) override;
virtual IndicatorImpPtr _clone() override;
private:
Indicator m_ref_ind;
//============================================
// 序列化支持
//============================================
#if HKU_SUPPORT_SERIALIZATION
private:
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int version) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(IndicatorImp);
ar& BOOST_SERIALIZATION_NVP(m_ref_ind);
}
#endif
};
}

View File

@ -76,7 +76,6 @@ static void spearmanLevel(const IndicatorImp::value_t *data, IndicatorImp::value
void ISpearman::_calculate(const Indicator &ind) {
auto k = getContext();
m_ref_ind.setContext(k);
Indicator ref = m_ref_ind;
if (m_ref_ind.size() != ind.size()) {

View File

@ -10,6 +10,7 @@
#include "../test_config.h"
#include <fstream>
#include <hikyuu/StockManager.h>
#include <hikyuu/indicator/crt/CORR.h>
#include <hikyuu/indicator/crt/KDATA.h>
#include <hikyuu/indicator/crt/PRICELIST.h>
@ -41,13 +42,13 @@ TEST_CASE("test_CORR") {
Indicator y = PRICELIST(b);
// 非法参数 n
result = CORR(x, y, 0);
result = CORR(x, y, -1);
CHECK_UNARY(result.empty());
result = CORR(x, y, 1);
CHECK_UNARY(result.empty());
// 正常情况
result = CORR(x, y, a.size());
result = CORR(x, y, 0);
CHECK_EQ(result.name(), "CORR");
CHECK_EQ(result.discard(), 2);
CHECK_EQ(result.size(), a.size());

View File

@ -165,21 +165,14 @@ TEST_CASE("test_SPEARMAN") {
price_t null_value = Null<price_t>();
x = PRICELIST({3., 8., null_value, 4., 7., 2., null_value, null_value});
y = PRICELIST({null_value, 5., 10., 8., null_value, 10., 6., null_value});
// expect = {null_value, , 1., 0.875, 1.};
// nan, 8, nan, 4, nan, 2, nan
// nan, 5, nan, 8, nan, 10, nan,
result = SPEARMAN(x, y, 4);
HKU_INFO("{}", result);
for (size_t i = result.discard(); i < result.size(); i++) {
HKU_INFO("{}: {}", i, result[i]);
}
x = PRICELIST({8., 4., 2.});
y = PRICELIST({5., 8., 10.});
result = SPEARMAN(x, y, x.size());
HKU_INFO("{}", result);
HKU_INFO("{}", std::pow(null_value, 2));
HKU_INFO("{}", 1.0 * 6.0 * null_value / (std::pow(x.size(), 3) - x.size()));
CHECK_EQ(result.name(), "SPEARMAN");
CHECK_EQ(result.discard(), 3);
CHECK_EQ(result.size(), x.size());
CHECK_UNARY(std::isnan(result[0]));
CHECK_EQ(result[5], doctest::Approx(-1.));
CHECK_EQ(result[6], doctest::Approx(-1.));
CHECK_UNARY(std::isnan(result[7]));
}
//-----------------------------------------------------------------------------

View File

@ -466,6 +466,9 @@ Indicator (*ZHBOND10_2)(const DatetimeList&, double) = ZHBOND10;
Indicator (*ZHBOND10_3)(const KData& k, double) = ZHBOND10;
Indicator (*ZHBOND10_4)(const Indicator&, double) = ZHBOND10;
Indicator (*CORR_1)(int) = CORR;
Indicator (*CORR_2)(const Indicator&, const Indicator&, int) = CORR;
Indicator (*SPEARMAN_1)(int) = SPEARMAN;
Indicator (*SPEARMAN_2)(const Indicator&, const Indicator&, int) = SPEARMAN;
@ -850,13 +853,15 @@ void export_Indicator_build_in(py::module& m) {
:param Indicator ind2: 2
:rtype: Indicator)");
m.def("CORR", CORR, R"(CORR(ind1, ind2, n)
m.def("CORR", CORR_1, py::arg("n") = 10);
m.def("CORR", CORR_2, py::arg("ind1"), py::arg("ind2"), py::arg("n") = 10,
R"(CORR(ind1, ind2, n)
ind1 ind2
:param Indicator ind1: 1
:param Indicator ind2: 2
:param int n: n ind
:param int n: n ind 0使ind长度
:rtype: Indicator)");
m.def("IF", IF_1);