hikyuu2/hikyuu_python/tools/maintain/hdf5import.py

595 lines
20 KiB
C++
Raw Normal View History

2018-10-17 02:15:56 +08:00
#!/usr/bin/python
# -*- coding: utf8 -*-
# cp936
import os.path
import struct
import sqlite3
import datetime
2018-10-27 22:59:39 +08:00
import math
2018-10-17 02:15:56 +08:00
import sys
import tables as tb
from io import SEEK_END, SEEK_SET
2018-10-21 01:56:30 +08:00
from common import *
2018-10-17 02:15:56 +08:00
2018-10-30 02:13:43 +08:00
HDF5_COMPRESS_LEVEL = 9
2018-10-17 02:15:56 +08:00
def ProgressBar(cur, total):
    """Draw a 50-character text progress bar on stdout.

    The bar is written after a carriage return so that successive calls
    overwrite the same terminal line.

    :param cur: amount of work finished so far
    :param total: total amount of work; zero or a negative value is treated
                  as "nothing to do" and rendered as a full bar instead of
                  raising ZeroDivisionError
    """
    if total > 0:
        percent = '{:.0%}'.format(cur / total)
        filled = int(math.floor(cur * 50 / total))
    else:
        # An empty job is trivially complete.
        percent = '100%'
        filled = 50
    sys.stdout.write('\r')
    sys.stdout.write("[%-50s] %s" % ('=' * filled, percent))
    sys.stdout.flush()
class H5Record(tb.IsDescription):
    """PyTables row layout for one K-line (candlestick) record."""

    # Timestamp as an integer; the importers in this file write
    # YYYYMMDDhhmm (daily records carry 0000 as the time part).
    datetime = tb.UInt64Col()
    # Prices are stored as scaled integers by the importers in this file.
    openPrice = tb.UInt32Col()
    highPrice = tb.UInt32Col()
    lowPrice = tb.UInt32Col()
    closePrice = tb.UInt32Col()
    # Traded amount and traded volume, scaled by the importers.
    transAmount = tb.UInt64Col()
    transCount = tb.UInt64Col()
class H5Index(tb.IsDescription):
    """PyTables row layout for one entry of a derived-period index table."""

    # End-of-period timestamp of the aggregated bar (same integer format
    # as H5Record.datetime).
    datetime = tb.UInt64Col()
    # Row number in the base data table where this period begins.
    start = tb.UInt64Col()
2018-10-19 01:44:04 +08:00
def create_database(connect):
    """Create the SQLite tables by running ``sqlite_createdb.sql``.

    The script is looked up in the current working directory. If SQLite
    reports an OperationalError (typically because the tables already
    exist) a notice is printed and the database is left untouched; any
    other error, including a missing script file, propagates to the caller.

    :param connect: open sqlite3 connection
    """
    filename = os.path.join(os.getcwd(), 'sqlite_createdb.sql')
    with open(filename, 'r', encoding='utf8') as sqlfile:
        script = sqlfile.read()

    cur = connect.cursor()
    try:
        cur.executescript(script)
        connect.commit()
    except sqlite3.OperationalError:
        print("相关数据表可能已经存在,放弃创建。如需重建,请手工删除相关数据表。")
    finally:
        # Release the cursor on every path, including the error ones.
        cur.close()
2018-10-21 01:56:30 +08:00
2018-10-27 22:59:39 +08:00
def get_marketid(connect, market):
    """Return the ``marketid`` of the row in table ``market`` named *market*.

    :param connect: open sqlite3 connection
    :param market: market code exactly as stored in the ``market`` column
    :return: the marketid of the first matching row
    :raises IndexError: if no row matches
    """
    cur = connect.cursor()
    try:
        # Bind the value instead of formatting it into the statement so
        # quotes in *market* cannot change the query.
        row = cur.execute("select marketid from market where market=?",
                          (market,)).fetchone()
    finally:
        cur.close()
    if row is None:
        raise IndexError("market '{}' not found in table market".format(market))
    return row[0]
2018-10-21 01:56:30 +08:00
def get_codepre_list(connect, marketid, quotation):
    """Fetch the code-prefix rules of a market, longest prefix first.

    :param connect: open sqlite3 connection
    :param marketid: value matched against ``coderuletype.marketid``
    :param quotation: passed through to ``get_stktype_list`` to choose the
                      security types of interest
    :return: list of ``(codepre, type)`` rows sorted by descending prefix
             length
    """
    # NOTE(review): the result of get_stktype_list is formatted straight
    # into the SQL "in" clause, so it presumably renders as "(1, 2, ...)";
    # confirm before switching this query to bound parameters.
    type_list = get_stktype_list(quotation)
    query = ("select codepre, type from coderuletype "
             "where marketid={marketid} and type in {type_list}").format(
                 marketid=marketid, type_list=type_list)

    cursor = connect.cursor()
    rules = cursor.execute(query).fetchall()
    cursor.close()

    # Longest prefixes first, so callers that stop at the first match get
    # the most specific rule.
    rules.sort(key=lambda rule: len(rule[0]), reverse=True)
    return rules
def tdx_import_stock_name_from_file(connect, filename, market, quotation=None):
    """Synchronise the ``stock`` table with a TDX code-table file.

    Updates each stock's name and validity flag, and inserts stocks that
    appear in the file but not yet in the database (start date = today).

    :param connect: open sqlite3 connection
    :param filename: path of the TDX code-table file
    :param market: 'SH' | 'SZ'
    :param quotation: 'stock' | 'fund' | 'bond' | None
    :raises IndexError: if *market* is not present in table ``market``
    """
    record_size = 314
    new_stock_dict = {}
    with open(filename, 'rb') as f:
        f.read(50)  # the first 50 bytes are skipped (presumably a file header)
        data = f.read(record_size)
        # A short trailing chunk is not a full record; stop there instead
        # of letting struct.unpack raise.
        while len(data) == record_size:
            fields = struct.unpack('6s 17s 8s 283s', data)
            stockcode = fields[0].decode()
            # Cut at the first NUL before decoding so padding bytes after
            # the terminator can never trip the GBK decoder.
            raw_name = fields[2].split(b'\x00', 1)[0]
            new_stock_dict[stockcode] = raw_name.decode(encoding='gbk').strip()
            data = f.read(record_size)

    cur = connect.cursor()
    try:
        row = cur.execute("select marketid from market where market = ?",
                          (market.upper(),)).fetchone()
        if row is None:
            raise IndexError("market '{}' not found in table market".format(market))
        marketid = row[0]

        old_stocks = cur.execute(
            "select stockid, code, name, valid from stock where marketid = ?",
            (marketid,)).fetchall()
        old_stock_dict = {}
        for oldstockid, oldcode, oldname, oldvalid in old_stocks:
            oldvalid = int(oldvalid)
            old_stock_dict[oldcode] = oldstockid

            # Stock is no longer in the code table: mark it invalid.
            if oldcode not in new_stock_dict:
                if oldvalid == 1:
                    cur.execute("update stock set valid=0 where stockid=?",
                                (oldstockid,))
                continue

            # Name changed: update it. Names come from an external file,
            # so they are always bound, never formatted into the SQL.
            if oldname != new_stock_dict[oldcode]:
                cur.execute("update stock set name=? where stockid=?",
                            (new_stock_dict[oldcode], oldstockid))
            # Previously invalid stock is listed again: re-activate it.
            if oldvalid == 0:
                cur.execute("update stock set valid=1, endDate=99999999 where stockid=?",
                            (oldstockid,))

        # Handle stocks that are new to the database.
        codepre_list = get_codepre_list(connect, marketid, quotation)

        today = datetime.date.today()
        today = today.year * 10000 + today.month * 100 + today.day
        for code, name in new_stock_dict.items():
            if code in old_stock_dict:
                continue
            # The first matching prefix rule decides the security type;
            # codes matching no rule are not imported.
            for codepre in codepre_list:
                if code.startswith(codepre[0]):
                    cur.execute(
                        "insert into Stock(marketid, code, name, type, valid, startDate, endDate) "
                        "values (?, ?, ?, ?, ?, ?, ?)",
                        (marketid, code, name, codepre[1], 1, today, 99999999))
                    break

        connect.commit()
    finally:
        cur.close()
2018-10-17 02:15:56 +08:00
2018-10-27 22:59:39 +08:00
def tdx_import_day_data_from_file(connect, filename, h5file, market, stock_record):
    """Append new daily K-line records from a TDX ``.day`` file to HDF5.

    Each 32-byte source record is unpacked as ``'iiiiifii'``; the fields
    used are date, open, high, low, close, amount and volume, in that
    order. Records dated on or before the last row already stored are
    skipped, as are records with a zero price or with open/close outside
    the [low, high] range.

    :param connect: open sqlite3 connection (only used to re-activate a
                    stock flagged invalid that already has stored data)
    :param filename: path of the source ``.day`` file
    :param h5file: open PyTables file
    :param market: market code, used as table-name prefix
    :param stock_record: (stockid, marketid, code, valid, type)
    :return: number of records appended; 0 if *filename* does not exist
    """
    add_record_count = 0
    if not os.path.exists(filename):
        return add_record_count

    stockid, code, valid, stktype = (stock_record[0], stock_record[2],
                                     stock_record[3], stock_record[4])

    try:
        group = h5file.get_node("/", "data")
    except tb.NoSuchNodeError:
        group = h5file.create_group("/", "data")

    tablename = market.upper() + code
    try:
        table = h5file.get_node(group, tablename)
    except tb.NoSuchNodeError:
        table = h5file.create_table(group, tablename, H5Record)

    if table.nrows > 0:
        # Integer arithmetic keeps the dates exact (YYYYMMDD).
        startdate = int(table[0]['datetime']) // 10000
        lastdatetime = int(table[-1]['datetime']) // 10000
    else:
        startdate = None
        lastdatetime = None

    record_size = 32
    row = table.row
    with open(filename, 'rb') as src_file:
        while True:
            data = src_file.read(record_size)
            # Stop at EOF; a truncated trailing record is ignored.
            if len(data) < record_size:
                break
            record = struct.unpack('iiiiifii', data)
            if lastdatetime is not None and record[0] <= lastdatetime:
                continue
            if 0 in record[1:5]:
                continue
            if not (record[2] >= record[1] >= record[3]
                    and record[2] >= record[4] >= record[3]):
                continue

            row['datetime'] = record[0] * 10000
            row['openPrice'] = record[1] * 10
            row['highPrice'] = record[2] * 10
            row['lowPrice'] = record[3] * 10
            row['closePrice'] = record[4] * 10
            row['transAmount'] = round(record[5] * 0.001)
            if stktype == 2:
                # Index: volume is stored unscaled.
                row['transCount'] = record[6]
            else:
                row['transCount'] = round(record[6] * 0.01)
            row.append()
            add_record_count += 1

    if add_record_count:
        table.flush()

    # A stock flagged invalid that already had stored data is re-activated.
    if startdate is not None and valid == 0:
        cur = connect.cursor()
        try:
            cur.execute("update stock set valid=1, startdate=?, enddate=? where stockid=?",
                        (startdate, 99999999, stockid))
            connect.commit()
        finally:
            cur.close()

    return add_record_count
2018-10-17 02:15:56 +08:00
2018-10-27 22:59:39 +08:00
def tdx_import_min_data_from_file(connect, filename, h5file, market, stock_record):
    """Append new minute K-line records from a TDX minute file to HDF5.

    Each 32-byte source record is unpacked as ``'HHfffffii'``: packed
    date, minute of day, open, high, low, close, amount, volume and one
    unused field. Only records later than the last row already stored
    are appended; the first such record is located by binary search, so
    the source file is assumed to be sorted by time.

    :param connect: sqlite3 connection (not used by this function)
    :param filename: path of the source minute file
    :param h5file: open PyTables file
    :param market: market code, used as table-name prefix
    :param stock_record: (stockid, marketid, code, valid, type)
    :return: number of records appended; 0 if *filename* does not exist
    """
    add_record_count = 0
    if not os.path.exists(filename):
        return add_record_count

    code, stktype = stock_record[2], stock_record[4]

    try:
        group = h5file.get_node("/", "data")
    except tb.NoSuchNodeError:
        group = h5file.create_group("/", "data")

    tablename = market.upper() + code
    try:
        table = h5file.get_node(group, tablename)
    except tb.NoSuchNodeError:
        table = h5file.create_table(group, tablename, H5Record)

    lastdatetime = int(table[-1]['datetime']) if table.nrows > 0 else None

    record_size = 32

    def trans_date(yymm, hhmm):
        """Decode the packed date / minute-of-day pair into YYYYMMDDhhmm."""
        year = (yymm >> 11) + 2004
        month, day = divmod(yymm & 0x7ff, 100)
        hh, mm = divmod(hhmm, 60)
        return year * 100000000 + month * 1000000 + day * 10000 + hh * 100 + mm

    row = table.row
    with open(filename, 'rb') as src_file:

        def get_date(pos):
            """Return the decoded timestamp of record number *pos*."""
            src_file.seek(pos * record_size, SEEK_SET)
            # Unsigned shorts: the year offset occupies the top bits, so a
            # signed read turns every date from 2020 onwards negative.
            yymm, hhmm = struct.unpack('HH', src_file.read(4))
            return trans_date(yymm, hhmm)

        src_file.seek(0, SEEK_END)
        file_total = src_file.tell() // record_size

        # Find the first record strictly later than the last stored row.
        pos = 0
        if lastdatetime is not None:
            low, high = 0, file_total
            while low < high:
                mid = (low + high) // 2
                if get_date(mid) <= lastdatetime:
                    low = mid + 1
                else:
                    high = mid
            pos = low

        if pos < file_total:
            src_file.seek(pos * record_size, SEEK_SET)
            data = src_file.read(record_size)
            # A truncated trailing record ends the loop instead of raising.
            while len(data) == record_size:
                record = struct.unpack('HHfffffii', data)
                if 0 not in record[2:6]:
                    if record[3] >= record[2] >= record[4] \
                            and record[3] >= record[5] >= record[4]:
                        row['datetime'] = trans_date(record[0], record[1])
                        # NOTE(review): prices are float * 1000 stored
                        # without rounding; confirm whether the integer
                        # column truncates values such as 9989.9998.
                        row['openPrice'] = record[2] * 1000
                        row['highPrice'] = record[3] * 1000
                        row['lowPrice'] = record[4] * 1000
                        row['closePrice'] = record[5] * 1000
                        row['transAmount'] = round(record[6] * 0.001)
                        if stktype == 2:
                            # Index: volume is stored unscaled.
                            row['transCount'] = record[7]
                        else:
                            # Volume is field 7; field 6 is the amount.
                            row['transCount'] = round(record[7] * 0.01)
                        row.append()
                        add_record_count += 1
                data = src_file.read(record_size)

    if add_record_count:
        table.flush()

    return add_record_count
def tdx_import_data(connect, market, ktype, quotation, src_dir, dest_dir, progress=ProgressBar):
    """Import TDX K-line files of one market into ``<market>_<ktype>.h5``.

    For every stock of the market whose type is selected by *quotation*,
    the matching source file in *src_dir* is imported. After new DAY or
    5MIN records are added, the derived-period index tables of that stock
    are updated as well.

    :param connect: open sqlite3 connection
    :param market: 'SH' | 'SZ' (case-insensitive)
    :param ktype: 'DAY' | '1MIN' | '5MIN' (case-insensitive)
    :param quotation: passed to ``get_stktype_list`` to select stock types
    :param src_dir: directory holding the TDX source files
    :param dest_dir: directory of the target HDF5 file
    :param progress: optional callback ``progress(index, total)`` invoked
                     after each stock; pass None to disable
    :return: total number of records appended
    :raises ValueError: if *ktype* is not one of the supported kinds
    """
    importers = {
        "DAY": (".day", tdx_import_day_data_from_file),
        "1MIN": (".lc1", tdx_import_min_data_from_file),
        "5MIN": (".lc5", tdx_import_min_data_from_file),
    }
    # Validate before touching the HDF5 file, and normalise the case once
    # so that dispatch and index maintenance agree on the same value.
    ktype_key = ktype.upper()
    if ktype_key not in importers:
        raise ValueError("unsupported ktype: {!r}".format(ktype))
    suffix, func_import_from_file = importers[ktype_key]

    add_record_count = 0
    market = market.upper()

    marketid = get_marketid(connect, market)
    stktype_list = get_stktype_list(quotation)
    sql = "select stockid, marketid, code, valid, type from stock where marketid={} and type in {}".format(marketid, stktype_list)
    cur = connect.cursor()
    try:
        stocks = cur.execute(sql).fetchall()
    finally:
        cur.close()

    h5name = os.path.join(dest_dir, "{}_{}.h5".format(market, ktype).lower())
    h5file = tb.open_file(h5name, "a", filters=tb.Filters(complevel=HDF5_COMPRESS_LEVEL, complib='zlib', shuffle=True))
    try:
        total = len(stocks)
        for i, stock in enumerate(stocks):
            filename = os.path.join(src_dir, market.lower() + stock[2] + suffix)
            this_count = func_import_from_file(connect, filename, h5file, market, stock)
            add_record_count += this_count
            # Only DAY and 5MIN data carry derived-period index tables.
            if this_count > 0 and ktype_key in ('DAY', '5MIN'):
                update_hdf5_extern_data(h5file, market + stock[2], ktype_key)
            if progress:
                progress(i, total)

        connect.commit()
    finally:
        # Close the HDF5 file even when an import step fails.
        h5file.close()
    return add_record_count
2018-10-17 02:15:56 +08:00
2018-10-30 02:13:43 +08:00
def _split_kdate(olddate):
    """Split a YYYYMMDDhhmm integer into (year, month, day)."""
    ymd = olddate // 10000
    return ymd // 10000, ymd // 100 % 100, ymd % 100


def _week_end_date(olddate):
    """Return the Friday of *olddate*'s week as YYYYMMDD0000."""
    y, m, d = _split_kdate(olddate)
    day = datetime.date(y, m, d)
    # date.weekday(): Monday is 0, Friday is 4.
    friday = day + datetime.timedelta(4 - day.weekday())
    return friday.year * 100000000 + friday.month * 1000000 + friday.day * 10000


def _month_end_date(olddate):
    """Return the last calendar day of *olddate*'s month as YYYYMMDD0000."""
    import calendar
    y, m, _ = _split_kdate(olddate)
    _, last_day = calendar.monthrange(y, m)
    return y * 100000000 + m * 1000000 + last_day * 10000


def _quarter_end_date(olddate):
    """Return the last day of *olddate*'s quarter as YYYYMMDD0000."""
    y, m, _ = _split_kdate(olddate)
    end_month = (m + 2) // 3 * 3
    end_day = 31 if end_month in (3, 12) else 30
    return y * 100000000 + end_month * 1000000 + end_day * 10000


def _halfyear_end_date(olddate):
    """Return June 30 or December 31 of *olddate*'s year as YYYYMMDD0000."""
    y, m, _ = _split_kdate(olddate)
    return y * 100000000 + (6300000 if m < 7 else 12310000)


def _year_end_date(olddate):
    """Return December 31 of *olddate*'s year as YYYYMMDD0000."""
    return olddate // 100000000 * 100000000 + 12310000


# Closing times (hhmm) of the intraday bars; anything later than the last
# entry falls into the bar closing at 1500.
_MIN15_BAR_ENDS = (945, 1000, 1015, 1030, 1045, 1100, 1115, 1130,
                   1315, 1330, 1345, 1400, 1415, 1430, 1445)
_MIN30_BAR_ENDS = (1000, 1030, 1100, 1130, 1330, 1400, 1430)
_MIN60_BAR_ENDS = (1030, 1130, 1400)


def _minute_bar_end_date(olddate, bar_ends):
    """Return the closing timestamp of the intraday bar containing *olddate*."""
    day_part = olddate // 10000 * 10000
    hhmm = olddate - day_part
    end = next((bar_end for bar_end in bar_ends if hhmm <= bar_end), 1500)
    return day_part + end


_INDEX_DATE_FUNCS = {
    'week': _week_end_date,
    'month': _month_end_date,
    'quarter': _quarter_end_date,
    'halfyear': _halfyear_end_date,
    'year': _year_end_date,
    'min15': lambda olddate: _minute_bar_end_date(olddate, _MIN15_BAR_ENDS),
    'min30': lambda olddate: _minute_bar_end_date(olddate, _MIN30_BAR_ENDS),
    'min60': lambda olddate: _minute_bar_end_date(olddate, _MIN60_BAR_ENDS),
}


def update_hdf5_extern_data(h5file, tablename, data_type):
    """Bring the derived-period index tables of one stock up to date.

    For ``data_type == 'DAY'`` the week / month / quarter / halfyear / year
    indexes are maintained; for any other value the min15 / min30 / min60
    indexes. Each index row holds the end-of-period timestamp and the row
    number in ``/data/<tablename>`` where that period starts. Missing
    groups and index tables are created on demand; if the base data table
    does not exist the function returns without building any index.

    :param h5file: open PyTables file
    :param tablename: name of the stock's table, e.g. market code + stock code
    :param data_type: 'DAY' selects the daily-derived indexes, anything
                      else the minute-derived ones
    """
    if data_type == 'DAY':
        index_list = ('week', 'month', 'quarter', 'halfyear', 'year')
    else:
        index_list = ('min15', 'min30', 'min60')

    group_dict = {}
    for index_type in index_list:
        try:
            group_dict[index_type] = h5file.get_node("/", index_type)
        except tb.NoSuchNodeError:
            group_dict[index_type] = h5file.create_group("/", index_type)

    try:
        table = h5file.get_node("/data", tablename)
    except tb.NoSuchNodeError:
        return

    for index_type in index_list:
        get_new_date = _INDEX_DATE_FUNCS[index_type]
        try:
            index_table = h5file.get_node(group_dict[index_type], tablename)
        except tb.NoSuchNodeError:
            index_table = h5file.create_table(group_dict[index_type], tablename, H5Index)

        if table.nrows == 0:
            continue

        index_row = index_table.row
        if index_table.nrows:
            pre_index_date = int(index_table[-1]['datetime'])
            # The last stored period already covers the newest data row.
            if pre_index_date == get_new_date(int(table[-1]['datetime'])):
                continue
            # Resume from the start of the last indexed period.
            startix = int(index_table[-1]['start'])
        else:
            startix = 0
            pre_index_date = get_new_date(int(table[0]['datetime']))
            index_row['datetime'] = pre_index_date
            index_row['start'] = 0
            index_row.append()

        for index, row in enumerate(table[startix:], startix):
            cur_index_date = get_new_date(int(row['datetime']))
            if cur_index_date != pre_index_date:
                index_row['datetime'] = cur_index_date
                index_row['start'] = index
                index_row.append()
                pre_index_date = cur_index_date
        index_table.flush()
2018-10-17 02:15:56 +08:00
2018-10-30 02:13:43 +08:00
#h5file.close()
2018-10-29 02:55:39 +08:00
#print('\n')
2018-10-17 02:15:56 +08:00
if __name__ == '__main__':
    # Maintenance entry point: refresh the stock code tables, then import
    # the K-line data sets that are currently enabled below.
    import time

    starttime = time.time()

    src_dir = "D:\\TdxW_HuaTai"
    dest_dir = "c:\\stock"

    connect = sqlite3.connect(dest_dir + "\\hikyuu-stock.db")
    create_database(connect)

    print("导入股票代码表")
    for tnf_name, market_code in (("shm.tnf", 'SH'), ("szm.tnf", 'SZ')):
        tdx_import_stock_name_from_file(connect,
                                        src_dir + "\\T0002\\hq_cache\\" + tnf_name,
                                        market_code, 'stock')

    add_count = 0

    # Disabled import steps; re-enable the ones that should run.
    #add_count = tdx_import_data(connect, 'SH', 'DAY', 'stock', src_dir + "\\vipdoc\\sh\\lday", dest_dir)
    #add_count = tdx_import_data(connect, 'SZ', 'DAY', 'stock', src_dir + "\\vipdoc\\sz\\lday", dest_dir)

    print("\n导入上证5分钟数据")
    #add_count = tdx_import_data(connect, 'SH', '5MIN', 'stock', src_dir + "\\vipdoc\\sh\\fzline", dest_dir)
    print("\n导入数量:", add_count)

    print("\n导入上证1分钟数据")
    add_count = tdx_import_data(connect, 'SH', '1MIN', 'stock', src_dir + "\\vipdoc\\sh\\minline", dest_dir)
    print("\n导入数量:", add_count)

    connect.close()

    elapsed = time.time() - starttime
    print("\nTotal time:")
    print("%.2fs" % elapsed)
    print("%.2fm" % (elapsed / 60))