Merge pull request #113 from yangrq1018/pr

sqlite kdata driver (support convert interval)
This commit is contained in:
fasiondog 2023-09-24 14:43:58 +08:00 committed by GitHub
commit 4036c182ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 405 additions and 2 deletions

View File

@ -60,6 +60,8 @@ def get_draw_title(kdata):
s1 = u' 线'
elif query.ktype == Query.MIN:
s1 = u' 1线'
elif query.ktype == Query.MIN3:
s1 = u' 3线'
elif query.ktype == Query.MIN5:
s1 = u' 5线'
elif query.ktype == Query.MIN15:
@ -68,13 +70,21 @@ def get_draw_title(kdata):
s1 = u' 30线'
elif query.ktype == Query.MIN60:
s1 = u' 60线'
elif query.ktype == Query.HOUR2:
s1 = u' 2线'
elif query.ktype == Query.HOUR4:
s1 = u' 4线'
elif query.ktype == Query.HOUR6:
s1 = u' 6线'
elif query.ktype == Query.HOUR12:
s1 = u' 12线'
name = stock.name
if stock.code == "":
stitle = "Block(%s) %s" % (stock.id, name) + s1
else:
stitle = stock.market + stock.code + ' ' + name + s1
stitle = stock.market + "/" + stock.code + ' ' + name + s1
return stitle

View File

@ -279,6 +279,7 @@ Query.QUARTER = "QUARTER"
Query.HALFYEAR = "HALFYEAR"
Query.YEAR = "YEAR"
Query.MIN = "MIN"
Query.MIN3 = "MIN3"
Query.MIN5 = "MIN5"
Query.MIN15 = "MIN15"
Query.MIN30 = "MIN30"

View File

@ -119,6 +119,9 @@ for p in preload_config:
kdata_param = Parameter()
kdata_config = ini.options('kdata')
for p in kdata_config:
if p == "convert":
kdata_param[p] = ini.getboolean('kdata', p)
continue
kdata_param[p] = ini.get('kdata', p)
#set_log_level(LOG_LEVEL.INFO)

View File

@ -24,7 +24,7 @@ const string KQuery::YEAR("YEAR");
const string KQuery::MIN3("MIN3");
const string KQuery::HOUR2("HOUR2");
const string KQuery::HOUR4("HOUR4");
const string KQuery::HOUR6("HOUR5");
const string KQuery::HOUR6("HOUR6");
const string KQuery::HOUR12("HOUR12");
// const string KQuery::INVALID_KTYPE("Z");
@ -33,11 +33,36 @@ static vector<string> g_all_ktype{KQuery::MIN, KQuery::MIN5, KQuery::MIN
KQuery::QUARTER, KQuery::HALFYEAR, KQuery::YEAR, KQuery::MIN3,
KQuery::HOUR2, KQuery::HOUR4, KQuery::HOUR6, KQuery::HOUR12};
static const unordered_map<string, int32_t> g_ktype2min{
{KQuery::MIN, 1},
{KQuery::MIN3, 3},
{KQuery::MIN5, 5},
{KQuery::MIN15, 15},
{KQuery::MIN30, 30},
{KQuery::MIN60, 60},
{KQuery::HOUR2, 60 * 2},
{KQuery::HOUR4, 60 * 4},
{KQuery::HOUR6, 60 * 6},
{KQuery::HOUR12, 60 * 12},
{KQuery::DAY, 60 * 24},
{KQuery::WEEK, 60 * 24 * 7},
{KQuery::MONTH, 60 * 24 * 30},
{KQuery::QUARTER, 60 * 24 * 30 * 3},
{KQuery::HALFYEAR, 60 * 24 * 30 * 6},
{KQuery::YEAR, 60 * 24 * 365},
};
// 获取所有的 KType
vector<string>& KQuery::getAllKType() {
return g_all_ktype;
}
int32_t KQuery::getKTypeInMin(KType ktype) {
return g_ktype2min.at(ktype);
}
KQuery::KQuery(Datetime start, Datetime end, KType ktype, RecoverType recoverType)
: m_start(start == Null<Datetime>() ? (int64_t)start.number()
: (int64_t)(start.number() * 100 + start.second())),

View File

@ -72,6 +72,8 @@ public:
/** 获取所有的 KType */
static vector<string>& getAllKType();
static int32_t getKTypeInMin(KType);
/**
*
* @note 线线/线

View File

@ -14,6 +14,7 @@
#include "kdata/mysql/MySQLKDataDriver.h"
#include "kdata/tdx/TdxKDataDriver.h"
#include "kdata/cvs/KDataTempCsvDriver.h"
#include "kdata/sqlite/SQLiteKDataDriver.h"
#include "DataDriverFactory.h"
#include "KDataDriver.h"
@ -40,6 +41,7 @@ void DataDriverFactory::init() {
DataDriverFactory::regKDataDriver(make_shared<H5KDataDriver>());
DataDriverFactory::regKDataDriver(make_shared<MySQLKDataDriver>());
DataDriverFactory::regKDataDriver(make_shared<KDataTempCsvDriver>());
DataDriverFactory::regKDataDriver(make_shared<SQLiteKDataDriver>());
}
void DataDriverFactory::release() {

View File

@ -88,6 +88,12 @@ public:
m_db_name, m_code);
}
string getSelectSQLNoDB() {
return fmt::format(
"select `date`,`open`,`high`, `low`, `close`, `amount`, `count` from `{}`",
m_code);
}
void save(const SQLStatementPtr& st) const {
st->bind(0, m_date, m_open, m_high, m_low, m_close, m_amount, m_count);
}

View File

@ -0,0 +1,288 @@
/*
* SQLiteKDataDriver.cpp
*
* Created on: 20230914
* Author: yangrq1018
*/
#include <fmt/format.h>
#include <boost/algorithm/string.hpp>
#include "SQLiteKDataDriver.h"
#include "../mysql/KRecordTable.h"
namespace hku {
inline bool isBaseKType(const KQuery::KType& ktype) {
return (ktype == KQuery::DAY || ktype == KQuery::MIN || ktype == KQuery::MIN5);
}
inline KQuery::KType getBaseKType(const KQuery::KType& ktype) {
KQuery::KType base_ktype;
if (ktype == KQuery::WEEK || ktype == KQuery::MONTH || ktype == KQuery::QUARTER ||
ktype == KQuery::HALFYEAR || ktype == KQuery::YEAR) {
base_ktype = KQuery::DAY;
} else if (ktype == KQuery::MIN15 || ktype == KQuery::MIN30 || ktype == KQuery::MIN60 ||
ktype == KQuery::HOUR2 || ktype == KQuery::HOUR4 || ktype == KQuery::HOUR6 ||
ktype == KQuery::HOUR12) {
base_ktype = KQuery::MIN5;
} else if (ktype == KQuery::MIN3) {
base_ktype = KQuery::MIN;
} else {
HKU_ERROR("Unable to convert ktype {} to a base ktype", ktype);
}
return base_ktype;
}
SQLiteKDataDriver::SQLiteKDataDriver() : KDataDriver("sqlite") {}
SQLiteKDataDriver::~SQLiteKDataDriver() {}
bool SQLiteKDataDriver::_init() {
HKU_ASSERT_M(m_sqlite_connection_map.empty(), "Maybe repeat initialization!");
// read param from config
StringList keys = m_params.getNameList();
string db_filename;
m_ifConvert = tryGetParam<bool>("convert", false);
HKU_DEBUG("SQLiteKDataDriver: m_ifConvert set to {}", m_ifConvert);
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
size_t pos = iter->find("_");
if (pos == string::npos || pos == 0 || pos == iter->size() - 1)
continue;
string exchange = iter->substr(0, pos);
string ktype = iter->substr(pos + 1);
to_upper(exchange);
to_upper(ktype);
try {
db_filename = getParam<string>(*iter);
Parameter connect_param;
connect_param.set<string>("db", db_filename);
SQLiteConnectPtr conn(new SQLiteConnect(connect_param));
if (ktype == KQuery::getKTypeName(KQuery::DAY)) {
m_sqlite_connection_map[exchange + "_DAY"] = conn;
if (m_ifConvert) {
m_sqlite_connection_map[exchange + "_WEEK"] = conn;
m_sqlite_connection_map[exchange + "_MONTH"] = conn;
m_sqlite_connection_map[exchange + "_QUARTER"] = conn;
m_sqlite_connection_map[exchange + "_HALFYEAR"] = conn;
m_sqlite_connection_map[exchange + "_YEAR"] = conn;
}
} else if (ktype == KQuery::getKTypeName(KQuery::MIN)) {
m_sqlite_connection_map[exchange + "_MIN"] = conn;
} else if (ktype == KQuery::getKTypeName(KQuery::MIN5)) {
m_sqlite_connection_map[exchange + "_MIN5"] = conn;
if (m_ifConvert) {
m_sqlite_connection_map[exchange + "_MIN15"] = conn;
m_sqlite_connection_map[exchange + "_MIN30"] = conn;
m_sqlite_connection_map[exchange + "_MIN60"] = conn;
m_sqlite_connection_map[exchange + "_HOUR2"] = conn;
}
}
} catch (...) {
HKU_ERROR("Can't open sqlite file: {}", db_filename);
}
}
return true;
}
string SQLiteKDataDriver::_getTableName(const string&, const string& code, KQuery::KType) {
string table = fmt::format("`{}`", code);
to_lower(table);
return table;
}
KRecordList SQLiteKDataDriver::getKRecordList(const string& market, const string& code,
const KQuery& query) {
KRecordList result;
KQuery::KType ktype = query.kType();
if (query.queryType() == KQuery::INDEX) {
if (!isBaseKType(ktype)) {
KQuery::KType base_ktype = getBaseKType(ktype);
int64_t start, end, num;
start = query.start();
end = query.end();
num = end - start;
int32_t multiplier = KQuery::getKTypeInMin(ktype) / KQuery::getKTypeInMin(base_ktype);
end = getCount(market, code, base_ktype);
start = end - num * multiplier;
HKU_ERROR_IF(start < 0, "Invalid start index: {}", start);
result = _getKRecordList(market, code, ktype, start, end);
} else {
result = _getKRecordList(market, code, ktype, query.start(), query.end());
}
} else {
result = _getKRecordList(market, code, ktype, query.startDatetime(), query.endDatetime());
}
if (isBaseKType(ktype))
return result;
HKU_ERROR_IF_RETURN(!m_ifConvert, KRecordList(), "KData: unsupported ktype {}", ktype);
KQuery::KType base_ktype = getBaseKType(ktype);
return convertToNewInterval(result, base_ktype, ktype);
}
KRecordList SQLiteKDataDriver::_getKRecordList(const string& market, const string& code,
KQuery::KType kType, size_t start_ix,
size_t end_ix) {
KRecordList result;
HKU_IF_RETURN(start_ix >= end_ix, result);
string key(format("{}_{}", market, kType));
SQLiteConnectPtr connection = m_sqlite_connection_map[key];
HKU_IF_RETURN(!connection, result);
try {
KRecordTable r(market, code, kType);
SQLStatementPtr st = connection->getStatement(fmt::format(
"{} order by date limit {}, {}", r.getSelectSQLNoDB(), start_ix, end_ix - start_ix));
st->exec();
while (st->moveNext()) {
KRecordTable record;
try {
record.load(st);
KRecord k;
k.datetime = record.date();
k.openPrice = record.open();
k.highPrice = record.high();
k.lowPrice = record.low();
k.closePrice = record.close();
k.transAmount = record.amount();
k.transCount = record.count();
result.push_back(k);
} catch (...) {
HKU_ERROR("Failed get record: {}", record.str());
}
}
} catch (...) {
// 表可能不存在
HKU_ERROR("Failed to get record by index: {}", key);
}
return result;
}
KRecordList SQLiteKDataDriver::_getKRecordList(const string& market, const string& code,
KQuery::KType kType, Datetime start_date,
Datetime end_date) {
KRecordList result;
HKU_IF_RETURN(start_date >= end_date, result);
string key(format("{}_{}", market, kType));
SQLiteConnectPtr connection = m_sqlite_connection_map[key];
HKU_IF_RETURN(!connection, result);
try {
KRecordTable r(market, code, kType);
SQLStatementPtr st = connection->getStatement(
fmt::format("{} where date >= {} and date < {} order by date", r.getSelectSQLNoDB(),
start_date.number(), end_date.number()));
st->exec();
while (st->moveNext()) {
KRecordTable record;
try {
record.load(st);
KRecord k;
k.datetime = record.date();
k.openPrice = record.open();
k.highPrice = record.high();
k.lowPrice = record.low();
k.closePrice = record.close();
k.transAmount = record.amount();
k.transCount = record.count();
result.push_back(k);
} catch (...) {
HKU_ERROR("Failed get record: {}", record.str());
}
}
} catch (...) {
// 表可能不存在
HKU_ERROR("Failed to get record by date: {}", key);
}
return result;
}
size_t SQLiteKDataDriver::getCount(const string& market, const string& code, KQuery::KType kType) {
string key(format("{}_{}", market, kType));
SQLiteConnectPtr connection = m_sqlite_connection_map[key];
HKU_IF_RETURN(!connection, 0);
size_t result = 0;
result = connection->queryInt(
fmt::format("select count(1) from {}", _getTableName(market, code, kType)));
if (isBaseKType(kType))
return result;
HKU_ERROR_IF_RETURN(!m_ifConvert, 0, "KData: unsupported ktype {}", kType);
auto old_intervals_per_new_candle =
KQuery::getKTypeInMin(kType) / KQuery::getKTypeInMin(getBaseKType(kType));
return result / old_intervals_per_new_candle;
}
bool SQLiteKDataDriver::getIndexRangeByDate(const string& market, const string& code,
const KQuery& query, size_t& out_start,
size_t& out_end) {
out_start = 0;
out_end = 0;
HKU_ERROR_IF_RETURN(query.queryType() != KQuery::DATE, false, "queryType must be KQuery::DATE");
HKU_IF_RETURN(
query.startDatetime() >= query.endDatetime() || query.startDatetime() > (Datetime::max)(),
false);
string key(format("{}_{}", market, query.kType()));
SQLiteConnectPtr connection = m_sqlite_connection_map[key];
HKU_IF_RETURN(!connection, false);
string tablename = _getTableName(market, code, query.kType());
try {
out_start = connection->queryInt(fmt::format("select count(1) from {} where date<{}",
tablename, query.startDatetime().number()));
out_end = connection->queryInt(fmt::format("select count(1) from {} where date<{}",
tablename, query.endDatetime().number()));
} catch (...) {
// 表可能不存在, 不打印异常信息
out_start = 0;
out_end = 0;
return false;
}
return true;
}
KRecordList SQLiteKDataDriver::convertToNewInterval(const KRecordList& candles,
KQuery::KType from_ktype,
KQuery::KType to_ktype) {
int32_t old_intervals_per_new_candle =
KQuery::getKTypeInMin(to_ktype) / KQuery::getKTypeInMin(from_ktype);
KRecordList result(candles.size() / old_intervals_per_new_candle);
if (result.size() == 0)
return result;
int32_t target = 0;
for (size_t x = 0, total = candles.size(); x < total; ++x) {
if (candles[x].openPrice != 0 && candles[x].highPrice != 0 && candles[x].lowPrice != 0 &&
candles[x].closePrice != 0) {
if (result[target].datetime.isNull())
result[target].datetime = candles[x].datetime;
if (result[target].openPrice == 0)
result[target].openPrice = candles[x].openPrice;
if (candles[x].highPrice > result[target].highPrice)
result[target].highPrice = candles[x].highPrice;
if ((result[target].lowPrice == 0) || (candles[x].lowPrice < result[target].lowPrice))
result[target].lowPrice = candles[x].lowPrice;
result[target].transCount += candles[x].transCount;
result[target].transAmount += candles[x].transAmount;
result[target].closePrice = candles[x].closePrice;
}
if ((x + 1) % old_intervals_per_new_candle == 0) {
if ((total - x) < old_intervals_per_new_candle + 1)
break;
target++;
}
}
return result;
}
} // namespace hku

View File

@ -0,0 +1,60 @@
/*
* SQLiteKDataDriver.h
*
* Created on: 20230914
* Author: yangrq1018
*/
#pragma once
#ifndef SQLITE_KDATA_DRIVER_H
#define SQLITE_KDATA_DRIVER_H
#include "../../../utilities/db_connect/DBConnect.h"
#include "../../../utilities/db_connect/sqlite/SQLiteConnect.h"
#include "../../KDataDriver.h"
namespace hku {
class SQLiteKDataDriver : public KDataDriver {
public:
SQLiteKDataDriver();
virtual ~SQLiteKDataDriver();
virtual KDataDriverPtr _clone() override {
return std::make_shared<SQLiteKDataDriver>();
}
virtual bool _init() override;
virtual bool isIndexFirst() override {
return false;
}
virtual bool canParallelLoad() override {
return true;
}
virtual size_t getCount(const string& market, const string& code, KQuery::KType kType) override;
virtual bool getIndexRangeByDate(const string& market, const string& code, const KQuery& query,
size_t& out_start, size_t& out_end) override;
virtual KRecordList getKRecordList(const string& market, const string& code,
const KQuery& query) override;
private:
string _getTableName(const string& market, const string& code, KQuery::KType ktype);
KRecordList _getKRecordList(const string& market, const string& code, KQuery::KType kType,
size_t start_ix, size_t end_ix);
KRecordList _getKRecordList(const string& market, const string& code, KQuery::KType ktype,
Datetime start_date, Datetime end_date);
static KRecordList convertToNewInterval(const KRecordList& candles, KQuery::KType from_type,
KQuery::KType to_ktype);
private:
unordered_map<string, SQLiteConnectPtr> m_sqlite_connection_map; // key: exchange+code
bool m_ifConvert = false;
};
} /* namespace hku */
#endif /* SQLITE_KDATA_DRIVER_H */

View File

@ -60,6 +60,10 @@ void hikyuu_init(const string& config_file_name, bool ignore_preload,
option = config.getOptionList("kdata");
for (auto iter = option->begin(); iter != option->end(); ++iter) {
if (*iter == "convert") {
kdataParam.set<bool>(*iter, config.getBool("kdata", *iter));
continue;
}
kdataParam.set<string>(*iter, config.get("kdata", *iter));
}

View File

@ -59,6 +59,8 @@ private:
sqlite3* m_db;
};
typedef shared_ptr<SQLiteConnect> SQLiteConnectPtr;
} // namespace hku
#endif /* HIYUU_DB_MANAGER_SQLITE_SQLITECONNECT_H */