简介:
处理过程大体如下:
1、使用tushare获取stock信息
2、对数据进行处理,做好train_x和train_y的对应关系
3、训练和预测:网络学习x与y之间的映射关系,网络结构要与数据相匹配
1、get_info.py 用于从tushare获取stock信息,目前只有日线数据,后面会增加(tushare积分不够,有些数据获取不到,欢迎大家注册帮忙增加积分)
2、stock_sql.py 用于将部分信息记录到数据库,方便查询检索。数据库使用的sqlite3
3、prepare.py 用于对数据进行处理,生成train_x,train_y的对应关系,满足网络训练需要。NN可以学习映射规律,进行预测。
4、trainning.py 定义网络模型并进行训练。
5、evaluate.py 对模型进行简单的评估。
6、server.py flask做的后台用于数据展示。
代码实现:
get_info.py
from concurrent.futures import ProcessPoolExecutor, as_completed
from pandas import DataFrame
import stock_sql
import pandas as pd
import numpy as np
import datetime
import time
import tushare as ts
import jieba
import re
from gensim.test.utils import common_texts, get_tmpfile
from gensim.models import Word2Vec,word2vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from sklearn.preprocessing import OneHotEncoder
import os

print(ts.__version__)
# The tushare token is a credential and must not be committed to source
# control. Supply it through the TUSHARE_TOKEN environment variable; when the
# variable is unset, ts.pro_api() uses a token saved earlier with
# ts.set_token().
# NOTE(review): the token that used to be hard-coded here should be revoked.
_tushare_token = os.environ.get("TUSHARE_TOKEN")
if _tushare_token:
    ts.set_token(_tushare_token)
pro = ts.pro_api()
# tushare ``stock_basic`` field name -> Chinese column description.
d = dict(
    ts_code="TS代码",
    symbol="股票代码",
    name="股票名称",
    area="所在地域",
    industry="所属行业",
    fullname="股票全称",
    enname="英文全称",
    market="市场类型(主板/中小板/创业板/科创板)",
    exchange="交易所代码",
    curr_type="交易货币",
    list_status="上市状态:L上市 D退市 P暂停上市",
    list_date="上市日期",
    delist_date="退市日期",
    is_hs="是否沪深港通标的,N否 H沪股通 S深股通",
)

# tushare ``daily`` field name -> Chinese column description.
daily = dict(
    ts_code="股票代码",
    trade_date="交易日期",
    open="开盘价",
    high="最高价",
    low="最低价",
    close="收盘价",
    pre_close="昨收价",
    change="涨跌额",
    pct_chg="涨跌幅(未复权)",
    vol="成交量(手)",
    amount="成交额(千元)",
)
def dateRange(beginDate, endDate):
    """Return every calendar day from beginDate through endDate.

    Both bounds are "YYYYMMDD" strings and are compared as strings; the
    result is a list of "YYYYMMDD" strings, empty when beginDate > endDate.
    """
    one_day = datetime.timedelta(1)
    cursor = datetime.datetime.strptime(beginDate, "%Y%m%d")
    current = beginDate[:]
    days = []
    while current <= endDate:
        days.append(current)
        cursor += one_day
        current = cursor.strftime("%Y%m%d")
    return days
class stock_basic:
    """
    Basic listing information for every stock (tushare ``stock_basic``).
    """
    # tushare field name -> Chinese column header used in the CSV files.
    basic_info = {
        "ts_code": "TS代码",
        "symbol": "股票代码",
        "name": "股票名称",
        "area": "所在地域",
        "industry": "所属行业",
        "fullname": "股票全称",
        "enname": "英文全称",
        "market": "市场类型(主板/中小板/创业板/科创板)",
        "exchange": "交易所代码",
        "curr_type": "交易货币",
        "list_status": "上市状态:L上市 D退市 P暂停上市",
        "list_date": "上市日期",
        "delist_date": "退市日期",
        "is_hs": "是否沪深港通标的,N否 H沪股通 S深股通",
    }

    def __init__(self):
        self.file_path = "data/stock_basic.csv"
        self.record_path = "data/data.csv"

    def getBasicInfo(self):
        """Download all listed stocks and save them to ``self.file_path``
        with the Chinese headers of ``basic_info``."""
        fields = list(self.basic_info.keys())
        headers = list(self.basic_info.values())
        data: DataFrame = pro.query('stock_basic', exchange='',
                                    list_status='L', fields=fields)
        data.columns = headers
        data.to_csv(self.file_path, header=True, encoding='utf-8', index=False)

    def recordDateInfo(self):
        """
        Write ``self.record_path``: one row per stock with its TS code, name,
        listing date, delisting date and an "更新时间" (update date) column
        initialised to the listing date.
        """
        data = pd.read_csv(self.file_path, dtype=object)
        # Columns to keep; every other basic-info column is dropped.
        keep = ["TS代码", "股票名称", "上市日期", "退市日期"]
        header = [val for val in self.basic_info.values() if val not in keep]
        print(header)
        data_1 = data.drop(header, axis=1)
        # Assign the column directly: writing row by row through
        # ``data_1.values`` only changes the frame when pandas happens to
        # return a view (and raises under copy-on-write).
        data_1.insert(loc=len(data_1.columns), column="更新时间",
                      value=data_1["上市日期"], allow_duplicates=False)
        data_1.to_csv(self.record_path, header=True, encoding='utf-8', index=False)

    @staticmethod
    def _one_hot(column):
        """One-hot encode a single-column frame into a dense numpy array.

        scikit-learn 1.2 renamed OneHotEncoder's ``sparse`` argument to
        ``sparse_output`` (1.4 removed the old name); try the new spelling
        first and fall back for older releases.
        """
        try:
            encoder = OneHotEncoder(sparse_output=False)
        except TypeError:
            encoder = OneHotEncoder(sparse=False)
        return encoder.fit_transform(column)

    def getExVec(self):
        """
        Build one-hot area and industry vectors for every stock.

        :return: (symbol -> area vector, symbol -> industry vector,
                  shape of the area matrix, shape of the industry matrix)
        """
        df = stock_sql.getStockFrame(headers=["symbol", "name", "area", "industry"])
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        # Missing area/industry values become their own " " category.
        df.fillna(" ", inplace=True)
        area_vec = self._one_hot(df[["area"]])
        indu_vec = self._one_hot(df[["industry"]])
        symbols = df["symbol"].tolist()
        data_area = dict(zip(symbols, area_vec))
        data_indu = dict(zip(symbols, indu_vec))
        return data_area, data_indu, area_vec.shape, indu_vec.shape
class stock_daily(object):
    """Download each stock's daily bars from tushare and merge them into its CSV file."""

    def __init__(self):
        pass

    def updateDailyInfo(self):
        """
        Update the daily bars of every stock listed in the stock_date table.
        :return: None
        """
        df = stock_sql.getTradeDateList()
        stock_list = np.array(df)
        self.updateDailyList(stock_list)

    def createDateTable(self):
        """
        Create the stock_date table that records update dates, then
        initialise it from the stock basic information.
        :return: None
        """
        stock_sql.createDateTable()
        stock_sql.initDateTable()

    def getDailyInfo(self, symbol, ts_code, start_time: str, end_time: str) -> pd.DataFrame:
        """
        Fetch a stock's daily bars between start_time and end_time.

        :param symbol: stock code (not used here)
        :param ts_code: tushare code of the stock
        :param start_time: first date, "YYYYMMDD"
        :param end_time: last date, "YYYYMMDD"
        :return: the fetched pages concatenated (days on page boundaries may
                 appear twice), or None when nothing was fetched
        """
        df_list = []
        while start_time != end_time:
            df_t = pro.daily(ts_code=ts_code, start_date=start_time, end_date=end_time)
            if df_t is None or df_t.empty:
                break  # no bars left in the requested range
            df_list.append(df_t)
            earliest = df_t["trade_date"].min()
            # One call may not cover the whole range, so page backwards by
            # moving end_time to the earliest date received. Stop once
            # start_time is reached, or when a page makes no progress (e.g.
            # start_time is not a trading day) -- otherwise the loop never ends.
            if earliest <= start_time or earliest == end_time:
                break
            end_time = earliest
        if df_list:
            return pd.concat(df_list, join="inner")
        return None

    def dailyInfoConcat(self, symbol: str, data: pd.DataFrame):
        """
        Merge newly fetched bars into the stock's stored CSV and save it.

        :param symbol: stock code
        :param data: newly fetched bars (its trade_date column is converted to int64)
        :return: trade_date of the first row of the merged frame -- the newest
                 fetched date, assuming tushare returns rows newest first
        """
        path = stock_sql.getFilePath(symbol)
        df = stock_sql.getStockData(path)
        data.trade_date = data.trade_date.astype("int64")
        # New rows go first so they win when de-duplicating by trade_date.
        data_n = pd.concat([data, df], join='inner', ignore_index=True)
        data_n.drop_duplicates(subset=["trade_date"], keep='first', inplace=True)
        data_n.to_csv(path, header=True, encoding='utf-8', index=False)
        return data_n.head(1)["trade_date"].values[0]

    def recordLastDate(self, symbol, data):
        """
        Record the date up to which the daily data of a stock is stored.

        :param symbol: stock code
        :param data: last update date
        :return: None
        """
        stock_sql.updateTradeDate(symbol, data)

    def updateDailyList(self, stockList, end_date=None):
        """
        Update the daily bars of the stocks in stockList.

        :param stockList: rows of (symbol, name, ts_code, last trade date)
        :param end_date: last date to fetch, "YYYYMMDD"; defaults to today
                         (this used to be hard-coded to "20200312")
        :return: None
        """
        if end_date is None:
            end_date = datetime.date.today().strftime("%Y%m%d")
        for item in stockList:
            print(item)
            data = self.getDailyInfo(item[0], item[2], item[3], end_date)
            if data is not None:
                last_d = self.dailyInfoConcat(item[0], data)
                self.recordLastDate(item[0], last_d)

    def updateDailyListMutliple(self, stockList):
        """
        Same as updateDailyList, but chunks of 500 stocks are processed in
        up to 8 worker processes.

        :param stockList: rows of (symbol, name, ts_code, last trade date)
        :return: None
        """
        stock_split = stock_sql.arr_split(stockList, 500)
        # The with-block shuts the pool down once every task has finished.
        with ProcessPoolExecutor(max_workers=8) as executor:
            all_task = [executor.submit(self.updateDailyList, list_s) for list_s in stock_split]
            for future in as_completed(all_task):
                future.result()  # re-raise any exception from a worker
class stock_company(object):
    """
    Company profile data (province, city, main business, business scope) and
    NLP experiments on the business-scope text: jieba segmentation, word2vec
    and doc2vec.
    """
    # Punctuation stripped by removePunctuation. re.escape is required: an
    # unescaped "]" in this set closed the character class early, so the
    # pattern practically never matched and no punctuation was removed.
    _PUNCTUATION_RE = re.compile('[{}]+'.format(re.escape('!,;:?()[]\n,。():;、\'')))

    def __init__(self):
        self.file_path = "data/stock_company.csv"
        # tushare field name -> Chinese description
        self.company_info = {
            "ts_code": "股票代码",
            "province": "所在省份",
            "city": "所在城市",
            "main_business": "主营业务",
            "business_scope": "经营范围"
        }

    def getCompanyInfo(self):
        """Download the company profiles from tushare and save them to self.file_path."""
        df: DataFrame = pro.stock_company(fields=list(self.company_info.keys()))
        df.to_csv(self.file_path, header=True, encoding='utf-8', index=False)
        print(df)

    def getCompanyData(self):
        """Load the saved profiles, sorted by ts_code, without the ts_code column."""
        drops = ["ts_code"]
        columns = [x for x in self.company_info.keys() if x not in drops]
        df: DataFrame = pd.read_csv(self.file_path, index_col=False, na_values="")
        df.sort_values(by=["ts_code"], ascending=True, ignore_index=True, inplace=True)
        return df[columns]

    def removePunctuation(self, text):
        """Replace runs of punctuation in *text* with a space, then strip and lower-case it."""
        text = self._PUNCTUATION_RE.sub(' ', text)
        return text.strip().lower()

    def word_cut_test(self):
        """Print the segmentation of some sample sentences in each jieba mode."""
        sentences = ["吸收公共存款",
                     "2015年我毕业于西安科技大学",
                     "2015年我毕业于西安电子科技大学",
                     "2015年我毕业于西安建筑科技大学",
                     "2015年我毕业于西安交通大学",
                     "2015年我毕业于北京大学"]
        for sentence in sentences:
            # full mode
            words = jieba.cut(sentence, cut_all=True)
            print("全模式: %s" % " ".join(words))
            words = jieba.cut(sentence, use_paddle=True)
            print("新词模式: %s" % " ".join(words))
            # default (accurate) mode
            words = jieba.cut(sentence)
            print("精确模式: %s" % " ".join(words))
            # search-engine mode
            words = jieba.cut_for_search(sentence)
            print("搜索模式: %s" % " ".join(words))

    def word_cut(self):
        """
        Segment every company's business scope and write one line per company
        to nlp/cut_words.txt, words separated by spaces, stopwords and
        single-character tokens removed.
        """
        jieba.enable_paddle()
        scopes = self.getCompanyData()["business_scope"].tolist()
        # A set makes each stopword lookup O(1).
        stopwords = set(self.stopwordslist("nlp/stop_words.txt"))
        with open("nlp/cut_words.txt", "w+", encoding='utf-8') as fw:
            for scope in scopes:
                # Missing business scopes are read back as NaN (a float).
                text = self.removePunctuation(scope if isinstance(scope, str) else "")
                words = jieba.cut(text, use_paddle=True)  # paddle mode
                santi_words = [x for x in words if len(x) > 1 and x not in stopwords]
                # PathLineSentences splits each line on whitespace, so the
                # words must be space separated (writelines() glued them).
                fw.write(" ".join(santi_words))
                fw.write("\n")
                print(santi_words)

    def stopwordslist(self, filepath):
        """Return the stripped lines of a UTF-8 stopword file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def word2vec(self):
        """Train a 20-dimensional word2vec model on nlp/cut_words.txt, save it and
        print the words closest to "新材料" together with its vector."""
        sentences = word2vec.PathLineSentences("nlp/cut_words.txt")
        try:
            # gensim >= 4.0 renamed ``size`` to ``vector_size``.
            model = Word2Vec(sentences, vector_size=20, window=5, min_count=1, workers=4)
        except TypeError:
            model = Word2Vec(sentences, size=20, window=5, min_count=1, workers=4)
        model.save("nlp/word2vec.model")
        model = Word2Vec.load("nlp/word2vec.model")
        vector = model.wv['新材料']
        # similar_by_vector is available on model.wv in gensim 3 and 4
        # (the model-level shortcut was removed in 4.0).
        a = model.wv.similar_by_vector(vector)
        print(a)
        print(vector)

    def doc2vec(self):
        """
        Train a 20-dimensional doc2vec model on nlp/cut_words.txt, save it,
        infer a vector for a sample text and return the closest words.
        """
        sentences = word2vec.PathLineSentences("nlp/cut_words.txt")
        documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(sentences)]
        model = Doc2Vec(documents, vector_size=20, window=2, min_count=1, workers=4)
        model.save("nlp/doc2vec.model")
        model = Doc2Vec.load("nlp/doc2vec.model")
        vector = model.infer_vector(["电器开关零部件及附件制造"])
        # gensim 4 dropped Doc2Vec.similar_by_vector; in gensim 3 it delegated
        # to model.wv, which works in both versions.
        return model.wv.similar_by_vector(vector)
if __name__ == "__main__":
    # Other entry points, run manually when needed:
    #   stock_record()
    #   stock_data()
    #   stock_company().doc2vec()
    stock_basic().getExVec()
stock_sql.py
from datetime import datetime
import pandas as pd
import numpy as np
import sqlite3
import os
import pickle
# tushare ``stock_basic`` field name -> Chinese description; the keys are
# also the column names of the stock_info table.
d = dict(
    ts_code="TS代码",
    symbol="股票代码",
    name="股票名称",
    area="所在地域",
    industry="所属行业",
    fullname="股票全称",
    enname="英文全称",
    market="市场类型(主板/中小板/创业板/科创板)",
    exchange="交易所代码",
    curr_type="交易货币",
    list_status="上市状态:L上市 D退市 P暂停上市",
    list_date="上市日期",
    delist_date="退市日期",
    is_hs="是否沪深港通标的,N否 H沪股通 S深股通",
)
# sqlite database file. os.path.join keeps the path portable: the former
# "data\my_stock.db" literal relied on the invalid escape "\m" and only
# pointed into the data directory on Windows.
DataBase = os.path.join("data", "my_stock.db")
RECORDTABLE = ''
class my_sql:
    """
    Context manager that opens a sqlite3 connection and yields a cursor.

    On exit the connection is committed and always closed. Exceptions raised
    inside the ``with`` body are deliberately NOT propagated (callers rely on
    this best-effort behaviour), but they are logged so failures are visible.
    """

    def __init__(self, baseName):
        self.conn = sqlite3.connect(baseName)

    def __enter__(self):
        return self.conn.cursor()

    def __exit__(self, exc_type, exc_value, exc_tb):
        try:
            self.conn.commit()
        finally:
            self.conn.close()
        if exc_type is not None:
            import logging
            logging.getLogger(__name__).error(
                "sqlite operation failed (exception suppressed)",
                exc_info=(exc_type, exc_value, exc_tb))
        return True  # do not re-raise: best-effort semantics
def CreateStockTabel():
    """Create the stock_info table with one CHAR(50) column per key of ``d``."""
    column_defs = ", ".join(key + " CHAR(50)" for key in d)
    Sql_str = "CREATE TABLE stock_info (" + column_defs + ") ;"
    print(Sql_str)
    with my_sql(DataBase) as cursor:
        cursor.execute(Sql_str)
def createRecordTable():
    """Create the stock_record table (per-day real/predicted data blobs) if it is missing."""
    ddl = ('CREATE TABLE IF NOT EXISTS stock_record ('
           'symbol TEXT PRIMARY KEY, name TEXT, time TEXT, real_data BLOB, predict_data BLOB )')
    with my_sql(DataBase) as cursor:
        cursor.execute(ddl)
def createDateTable():
    """
    Create the stock_date table (if missing), which records per stock the
    listing date, the date the daily data reaches (trade_date) and the date
    training reaches (train_date).
    :return: None
    """
    ddl = ('CREATE TABLE IF NOT EXISTS stock_date ('
           'symbol TEXT PRIMARY KEY, name TEXT, ts_code TEXT, '
           'list_date TEXT, trade_date TEXT, train_date TEXT)')
    with my_sql(DataBase) as cursor:
        cursor.execute(ddl)
def saveStockData():
    """
    Store data/stock_basic.csv in the stock_info table (replacing it), using
    the English field names of ``d`` as column names.
    :return: None
    """
    frame = pd.read_csv("data/stock_basic.csv", dtype=object)
    frame.columns = [key for key in d]
    connection = sqlite3.connect(DataBase)
    frame.to_sql(name="stock_info", con=connection, index=False, if_exists="replace")
    connection.close()
def initDateTable():
    """
    Seed stock_date with one row per stock in stock_info.

    trade_date (how far the daily data has been downloaded) and train_date
    (how far the model has been trained) both start at the listing date.
    :return: None
    """
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql("SELECT symbol, name, ts_code,list_date FROM stock_info ", con=conn)
    finally:
        conn.close()
    # stock_date has no ``info_date`` column (see createDateTable); inserting
    # into it failed on every row and my_sql suppressed the error, so the
    # table was never initialised. The column is trade_date.
    sqlStr = ("REPLACE INTO stock_date (symbol, name, ts_code, list_date, trade_date, train_date) "
              "VALUES (?,?,?,?,?,?)")
    rows = [(row["symbol"], row["name"], row["ts_code"],
             row["list_date"], row["list_date"], row["list_date"])
            for _, row in df.iterrows()]
    with my_sql(DataBase) as c:
        c.executemany(sqlStr, rows)
def saveStockPath():
    """
    Store the location of each stock's daily CSV (data/info/<symbol>.csv) in
    the stock_path table (columns: name, symbol, path), replacing it.
    :return: None
    """
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql("SELECT name, symbol, ts_code FROM stock_info ", con=conn)
        paths = ["".join(["data/info/", symbol, ".csv"]) for symbol in df["symbol"]]
        df.drop(columns=["ts_code"], inplace=True)
        df.insert(loc=2, column="path", value=paths)
        df.to_sql(name="stock_path", con=conn, index=False, if_exists='replace')
    finally:
        conn.close()  # close even when the query or the write fails
    print(df)
def initTrainDate():
    """Reset every stock's train_date to its list_date, so training starts
    again from the listing date."""
    with my_sql(DataBase) as cursor:
        cursor.execute("UPDATE stock_date SET train_date = list_date")
    print("initTrainDate")
def updateTradeDate(symbol, date):
    """
    Record the date up to which a stock's daily data has been stored.

    :param symbol: stock code
    :param date: trade date
    :return: None
    """
    sqlStr = "UPDATE stock_date SET trade_date = ? WHERE symbol = ?"
    with my_sql(DataBase) as c:
        # str() stores the same text the old string formatting produced and
        # avoids sqlite3's inability to bind numpy integer types.
        c.execute(sqlStr, (str(date), str(symbol)))
def updateTrianList(stockList):
    """
    Mark the stocks as trained up to their current trade_date.

    :param stockList: rows whose first element is the symbol
    :return: None
    """
    sqlStr = "UPDATE stock_date SET train_date = trade_date WHERE symbol = ?"
    with my_sql(DataBase) as c:
        c.executemany(sqlStr, [(str(item[0]),) for item in stockList])
def getPathList(symbolList):
    """
    Look up the daily-CSV location of each stock.

    :param symbolList: rows of (symbol, name)
    :return: list of (path, symbol); symbols without a stored path are skipped
    """
    pathList = []
    sqlStr = "SELECT path FROM stock_path WHERE symbol = ?"
    with my_sql(DataBase) as c:
        for item in symbolList:
            row = c.execute(sqlStr, (str(item[0]),)).fetchone()
            if row is None:
                # Previously this raised inside my_sql, which suppressed the
                # error and silently cut the list short at this symbol.
                continue
            pathList.append((row[0], item[0]))
    return pathList
def getFilePath(symbol):
    """
    Return the path of a stock's daily CSV file.

    :param symbol: stock code
    :raises KeyError: when stock_path has no entry for the symbol
    """
    # A bound parameter instead of a double-quoted literal: SQLite parses
    # "..." as an identifier first, so a symbol equal to a column name
    # would match every row.
    row = None
    with my_sql(DataBase) as c:
        row = c.execute("SELECT path FROM stock_path WHERE symbol = ?", (str(symbol),)).fetchone()
    if row is None:
        raise KeyError("no stock_path entry for symbol {!r}".format(symbol))
    return row[0]
def getLastTradeList(pathList):
    """
    Return the newest stored trade of each stock.

    :param pathList: rows of (path, symbol)
    :return: list of (symbol, last trade_date, ts_code), values as read from the CSV text
    """
    def _newest(csv_path):
        frame = pd.read_csv(csv_path, dtype=object)
        frame.sort_values(by=["trade_date"], inplace=True, ignore_index=True)
        return frame.tail(1)[["trade_date", "ts_code"]].values[0]

    result = []
    for entry in pathList:
        trade_date, ts_code = _newest(entry[0])
        result.append((entry[1], trade_date, ts_code))
    return result
def getLastDate(symbol):
    """
    Return how far a stock's data has been downloaded and trained.

    :param symbol: stock code
    :return: (trade_date, train_date)
    :raises KeyError: when stock_date has no row for the symbol
    """
    sqlStr = "SELECT symbol, trade_date, train_date FROM stock_date WHERE symbol = ?"
    data = None
    with my_sql(DataBase) as c:
        data = c.execute(sqlStr, (str(symbol),)).fetchone()
    if data is None:
        raise KeyError("no stock_date row for symbol {!r}".format(symbol))
    return data[1], data[2]
def getLastTradeDate(symbol):
    """
    Return the newest stored trade of a stock.

    :param symbol: stock code
    :return: numpy array [trade_date, ts_code], values as read from the CSV text
    """
    frame = pd.read_csv(getFilePath(symbol), dtype=object)
    frame = frame.sort_values(by=["trade_date"], ignore_index=True)
    return frame.tail(1)[["trade_date", "ts_code"]].values[0]
def getTradeDateList():
    """Return symbol, name, ts_code and trade_date of every stock_date row as a DataFrame."""
    connection = sqlite3.connect(DataBase)
    frame = pd.read_sql("SELECT symbol, name, ts_code,trade_date FROM stock_date", con=connection)
    connection.close()
    return frame
def getStockFrame(headers=None):
    """
    Read selected columns of every stock_info row.

    :param headers: column names to select; defaults to ["symbol", "name"]
                    (None instead of a shared mutable list default)
    :return: DataFrame with those columns
    """
    if headers is None:
        headers = ["symbol", "name"]
    sqlStr = "SELECT {} FROM stock_info".format(columnsToSql(headers))
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql(sqlStr, con=conn, columns=headers)
    finally:
        conn.close()
    return df
def getStockInfo(item):
    """
    Look up descriptive fields of a stock and format them for display.

    :param item: mapping with a "symbol" key
    :return: one-row DataFrame whose cells read "<label>: <value>", or an
             empty DataFrame when item is empty or the symbol does not match
             exactly one row
    """
    if not item:
        return pd.DataFrame()
    columns = ["symbol", "area", "industry", "market", "list_date", "is_hs"]
    sqlStr = "SELECT {} FROM stock_info WHERE symbol = ?".format(columnsToSql(columns))
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql(sqlStr, con=conn, params=(str(item["symbol"]),))
    finally:
        conn.close()
    df.fillna("", inplace=True)
    if df.shape[0] != 1:
        return pd.DataFrame()
    for index, key in enumerate(columns):
        name, value = parseStockInfo(key, df.iat[0, index])
        print(name, value)
        # Write through .iat: assigning into df.values only changes the frame
        # when pandas returns a view, which copy-on-write does not guarantee.
        df.iat[0, index] = name + ": " + value
    return df
def getDailyFrame(symbol: str) -> pd.DataFrame:
    """
    Load a stock's daily bars: duplicate dates removed, missing values set
    to 0, rows sorted by trade_date ascending with a fresh 0..n-1 index.
    """
    frame = getStockData(getFilePath(symbol))
    frame = frame.fillna(0)
    return frame.sort_values(by=["trade_date"], ascending=True, ignore_index=True)
def getStockData(path):
    """
    Read a daily-bar CSV, treating 0 as a missing value, and keep only the
    first row of each trade_date.
    """
    return pd.read_csv(path, index_col=False, na_values=0).drop_duplicates(
        subset=["trade_date"], keep="first")
def getRecordData(symbol):
    """
    Return today's cached (real, predict) data for a stock, or None when
    there is no record for today (or symbol is empty).
    """
    if not symbol:
        return None
    now_date = datetime.now().strftime("%Y%m%d")
    sqlStr = ("SELECT real_data, predict_data, symbol FROM stock_record"
              " WHERE symbol = ? AND time = ?")
    data = None
    with my_sql(DataBase) as c:
        data = c.execute(sqlStr, (str(symbol), now_date)).fetchone()
    if not data:
        return None
    # NOTE(security): pickle.loads can execute arbitrary code. This assumes
    # the blobs were written by SaveRecordData into a trusted local database.
    return pickle.loads(data[0]), pickle.loads(data[1])
def SaveRecordData(stockInfo, realData, predictData):
    """Cache today's real and predicted data of a stock in stock_record as pickled blobs."""
    real_blob = pickle.dumps(realData)
    predict_blob = pickle.dumps(predictData)
    today = datetime.now().strftime("%Y%m%d")
    with my_sql(DataBase) as cursor:
        cursor.execute(
            "REPLACE INTO stock_record (symbol,name,time,real_data, predict_data) VALUES (?,?,?,?,?) ",
            (stockInfo["symbol"], "", today, real_blob, predict_blob))
def ClearRecordData():
    """Drop the stock_record table; errors (e.g. a missing table) are suppressed by my_sql."""
    with my_sql(DataBase) as cursor:
        cursor.execute("DROP TABLE stock_record")
def parseStockInfo(key, value):
    """
    Translate a stock_info field into a display label and a readable value.

    :param key: field name; must be a key of ``d`` (KeyError otherwise)
    :param value: raw field value
    :return: (label, value) with coded values such as "L" or "CNY" expanded
    """
    short_labels = {"market": "市场类型", "list_status": "上市状态", "is_hs": "沪深港通"}
    value_labels = {
        "curr_type": {"CNY": "RMB"},
        "list_status": {"L": "上市", "D": "退市", "P": "暂停上市"},
        "is_hs": {"S": "深港通", "H": "沪港通", "N": "否"},
    }
    label = d[key]
    label = short_labels.get(key, label)
    value = value_labels.get(key, {}).get(value, value)
    return label, value
def columnsToSql(columns):
    """
    Join column names into a comma-separated SQL column list.

    :param columns: iterable of column-name strings
    :return: e.g. "symbol,name"; "" for no columns
    """
    return ",".join(columns)
def arr_split(arr, size):
    """
    Split ``arr`` (list or numpy array) into consecutive chunks of at most
    ``size`` items.

    The range bound is len(arr): the former len(arr) + 1 added an empty
    trailing chunk whenever len(arr) was a multiple of size (including
    empty input).
    """
    return [arr[i:i + size] for i in range(0, len(arr), size)]
def DataBaseInit():
    """Load the stock basic information into the database."""
    # saveStockData writes stock_info with if_exists="replace", which creates
    # the table, so CreateStockTabel() is not called here.
    saveStockData()
def sqlTest():
    """Ad-hoc check: print the first three rows of the user_information table."""
    # Use the shared DataBase path; the old "data\my_stock.db" literal used
    # an invalid escape and only resolved correctly on Windows.
    con = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql('select * from user_information LIMIT 3', con)
    finally:
        con.close()
    print(df)
if __name__ == '__main__':
    # One-off maintenance steps, run manually when needed:
    #   DataBaseInit()
    #   getStockList()
    #   saveStockPath()
    #   getFilePath("000001")
    initTrainDate()
prepare.py
# -*-coding:utf-8 -*-
import sys
#print(sys.executable)
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import stock_sql
from abc import ABCMeta, abstractmethod
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor,as_completed
import time
import get_info
import datetime
#pd.show_versions()
# tushare ``daily`` field name -> Chinese description.
daily = dict(
    ts_code="股票代码",
    trade_date="交易日期",
    open="开盘价",
    high="最高价",
    low="最低价",
    close="收盘价",
    pre_close="昨收价",
    change="涨跌额",
    pct_chg="涨跌幅(未复权)",
    vol="成交量(手)",
    amount="成交额(千元)",
)
class stockInfo:
    """Identifies one stock by its symbol and name."""

    def __init__(self, symbol: str = "000001", name: str = "平安银行"):
        self.name = name
        self.symbol = symbol
# Default window length for parseUntrainedData. It must match the window
# length the prepare classes actually use (prepareBase.SEQUENCE_LEN = 50);
# with 20 here, the "first untrained window" index was computed for
# 20-row windows while the samples were built from 50-row windows.
SEQUENCE_LEN = 50
# Default number of most recent rows returned by getTrianColnum /
# getPredictTrainnData.
RREDICT_LEN = 30
def _first_untrained_window(trade_dates, last_date, sequence_len):
    """
    Return the index of the first training window not yet trained on.

    Window k (prepareBase.dataSequence / dataSequence_y) uses rows
    k .. k+sequence_len-1 as input and row k+sequence_len as its target.
    ``trade_dates`` is the ascending list of row dates and ``last_date`` the
    newest target date already trained on.
    """
    for row, trade_date in enumerate(trade_dates):
        if trade_date >= last_date:
            if trade_date == last_date:
                # Targets up to and including this row are trained already.
                return max(0, row + 1 - sequence_len)
            # last_date is not a row (e.g. the listing date of a never-trained
            # stock or a non-trading day): this row is still untrained.
            return max(0, row - sequence_len)
    # last_date is newer than every row: nothing is left to train.
    return len(trade_dates)


def parseUntrainedData(symbol, df, sequence_len=None):
    """
    Find where the not-yet-trained samples of a stock start.

    :param symbol: stock code, used to read train_date from stock_date
    :param df: daily bars sorted by trade_date ascending (as from getDailyFrame)
    :param sequence_len: window length; defaults to the module SEQUENCE_LEN
    :return: index of the first untrained window; >= the number of windows
             when nothing is left to train
    """
    if sequence_len is None:
        sequence_len = SEQUENCE_LEN
    lastDate = int(stock_sql.getLastDate(symbol)[1])
    return _first_untrained_window(df["trade_date"].tolist(), lastDate, sequence_len)
def getTrianColnum(stockinfo: stockInfo, length=RREDICT_LEN):
    """
    Return the last ``length`` closing prices used as model inputs.

    Equivalent to the "close" column of prepareBase.dataProcess(df)[0]:
    every row except the newest (which has no next-day target). It is
    computed directly instead of constructing a prepareLogistics object
    just to split the columns.

    :return: array of shape (min(length, rows - 1), 1)
    """
    df = stock_sql.getDailyFrame(stockinfo.symbol)
    close = np.array(df[["close"]].iloc[:-1])
    return close[-length:]
def columnSplit(verify=False):
    """
    Choose the feature columns (x) and target columns (y) of the daily data.

    :param verify: when True, y also carries trade_date so targets can be
                   checked against their dates
    :return: (column_x, column_y)
    """
    column_x = [name for name in daily if name != 'ts_code']
    column_y = ["close", 'trade_date'] if verify else ["close"]
    return column_x, column_y
def getPredictTrainnData(stockInfo, length=RREDICT_LEN):
    """
    Return the last ``length`` feature rows of a stock, rounded to 2 decimals.

    :param stockInfo: mapping with a "symbol" key
    """
    frame = stock_sql.getDailyFrame(stockInfo["symbol"])
    feature_columns = columnSplit()[0]
    rounded = np.around(np.array(frame[feature_columns]), decimals=2)
    return rounded[-length:]
def dataSplit(npArr, split=0.9):
    """
    Split an array along axis 0 into a leading train part and a trailing
    validation part.

    :param npArr: array to split
    :param split: fraction of rows (truncated to an int count) in the first part
    :return: (train, test)
    """
    boundary = int(npArr.shape[0] * split)
    return npArr[:boundary], npArr[boundary:]
class prepareBase:
    """
    Turn a stock's daily bars into supervised training samples.

    Sample k is the window of SEQUENCE_LEN consecutive feature rows starting
    at row k; its target is the close of the row right after the window.
    Subclasses change the target (dataProcess / dataSequence_y) or add extra
    model inputs (tuple_x > 1, with x returned as a tuple of arrays).
    """

    def __init__(self):
        print(str(type(self)) + " __build __")
        self.verify = False        # keep trade_date in y and skip scaling
        self.only_untrain = False  # only return windows not trained on yet
        self.batch_size = 200      # stocks per batch yielded by dataGenerator
        self.SEQUENCE_LEN = 50     # rows per input window
        self.RREDICT_LEN = 30      # samples returned by getTestData
        self.tuple_x = 1           # number of model inputs per sample

    def getInputShape(self):
        """
        Return the shape of one input window: (SEQUENCE_LEN, number of features).
        """
        column_x, _ = columnSplit()
        return (self.SEQUENCE_LEN, len(column_x))

    def dataProcess(self, df: pd.DataFrame) -> (np.array, np.array):
        """
        Split the daily frame into feature rows x and next-day targets y.

        :param df: daily bars; assumes a 0..n-1 index as produced by
                   stock_sql.getDailyFrame
        :return: (x, y) where x holds rows 0..n-2 and y[i] holds the target
                 columns of row i+1
        """
        column_x, column_y = columnSplit(self.verify)
        x, y = df[column_x], df[column_y]
        # The last x row has no next-day target and the first y row has no
        # input row; drop both so that x[i] pairs with y[i].
        x, y = x.drop(len(x) - 1, axis=0), y.drop(0, axis=0)
        y.reset_index(drop=True, inplace=True)
        return x, y

    def dataSequence(self, x: pd.DataFrame, nor: bool = True) -> np.array:
        """
        Cut the feature rows into overlapping windows of SEQUENCE_LEN rows.

        :param x: feature rows
        :param nor: min-max scale the data (fitted on x) before windowing
        :return: (windows as float64 array, the scaler -- unfitted when nor is False)
        """
        scaler = MinMaxScaler()
        if nor:
            data_all = np.array(x).astype("float64")
            data_all = scaler.fit_transform(data_all)
        else:
            data_all = np.array(x)
        data = []
        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):
            data.append(data_all[i: i + self.SEQUENCE_LEN])
        x = np.array(data).astype('float64')
        return x, scaler

    def dataSequence_y(self, y: pd.DataFrame, nor: bool = True) -> np.array:
        """
        Pick, for every window built by dataSequence, the target aligned with
        the window's last row.

        :param y: target rows
        :param nor: min-max scale the data (fitted on y) first
        :return: (targets, the scaler -- unfitted when nor is False)
        """
        scaler = MinMaxScaler()
        if nor:
            data_all = np.array(y).astype("float64")
            data_all = scaler.fit_transform(data_all)
        else:
            data_all = np.array(y)
        data = []
        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):
            data.append(data_all[i + self.SEQUENCE_LEN - 1])
        x = np.array(data)
        return x, scaler

    def test(self):
        """Debug aid: for every stock print each window's last input row next
        to its target (verify mode, untrained windows only)."""
        print("+++++++++++++++++++++++")
        self.verify = True
        self.only_untrain = True
        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        for info in np.array(df):
            data = self.getTrainData(stockInfo(symbol=info[0]))
            if data is None:
                continue  # not enough (untrained) history for this stock
            data_x, data_y = data[0], data[1]
            for window, target in zip(data_x, data_y):
                print(window[self.SEQUENCE_LEN - 1])
                print(target)

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """
        Build the training samples of one stock.

        :param stockinfo: stock to load (uses .symbol)
        :return: (x, y, scaler_x, scaler_y), or None when the stock has too
                 few rows or (with only_untrain) no untrained windows
        """
        symbol = stockinfo.symbol
        df = stock_sql.getDailyFrame(symbol)
        if len(df) <= self.SEQUENCE_LEN:
            return None
        x, y = self.dataProcess(df)
        train_x, scaler_x = self.dataSequence(x, nor=not self.verify)
        train_y, scaler_y = self.dataSequence_y(y, nor=not self.verify)
        if self.only_untrain:
            index = parseUntrainedData(symbol, df)
            if index >= len(train_x):
                return None  # every window has been trained on already
            train_x, train_y = train_x[index:], train_y[index:]
        return train_x, train_y, scaler_x, scaler_y

    def getTestData(self, stockInfo: stockInfo):
        """
        Return the last RREDICT_LEN samples of a stock plus its scalers.

        :return: (x, y, scaler_x, scaler_y), or None when there are no samples
        """
        data = self.getTrainData(stockInfo)
        if data is None:
            return None
        train_x, train_y, scaler_x, scaler_y = data[0], data[1], data[2], data[3]
        if len(train_y) == 0:
            return None  # the old ">= 0" check was always true
        tail = -self.RREDICT_LEN
        if isinstance(train_x, tuple):
            # Multi-input subclasses return a tuple of arrays; slice each one
            # (slicing the tuple itself would return all of its arrays).
            test_x = tuple(part[tail:] for part in train_x)
        else:
            test_x = train_x[tail:]
        return test_x, train_y[tail:], scaler_x, scaler_y

    def dataGenerator(self, index: int = 0):
        """
        Yield training batches of batch_size stocks each.

        :param index: skip the first ``index`` stocks (sorted by symbol)
        :return: generator of (list of x arrays, y array, stock rows of the batch)
        """
        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        stock_array = np.array(df)[index:]
        batch_array = stock_sql.arr_split(stock_array, self.batch_size)
        for s_array in batch_array:
            list_x, list_y = self.dataConcat(s_array)
            if len(list_x[0]) == 0 or len(list_y) == 0:
                continue
            data_x = [np.concatenate(list_x[i], axis=0) for i in range(self.tuple_x)]
            data_y = np.concatenate(list_y, axis=0)
            yield data_x, data_y, s_array

    def dataConcat(self, stock_list):
        """
        Collect the training samples of several stocks.

        :param stock_list: rows whose first element is the symbol
        :return: (list_x, list_y): list_x[i] holds one array per stock for
                 model input i, list_y one target array per stock
        """
        list_x = [[] for _ in range(self.tuple_x)]
        list_y = []
        for info in stock_list:
            data = self.getTrainData(stockInfo(symbol=info[0]))
            print("stock info :{}\r\n".format(info))
            if data is None:
                continue
            # Single-input classes return x as one array, multi-input ones as
            # a tuple. Indexing the bare array with [0] took only the first
            # window of each stock, so normalise to a tuple first.
            inputs = data[0] if isinstance(data[0], tuple) else (data[0],)
            for i in range(self.tuple_x):
                list_x[i].append(inputs[i])
            list_y.append(data[1])
        return list_x, list_y

    def dataConcatMultiple(self, stock_array, only_untrain=False):
        """
        Parallel dataConcat: split stock_array into 4 parts processed by 4
        worker processes.

        :param only_untrain: unused; kept for backward compatibility
        :return: (list_x, list_y) as returned by dataConcat
        """
        list_x = [[] for _ in range(self.tuple_x)]
        list_y = []
        split_array = np.array_split(stock_array, 4, axis=0)
        with ProcessPoolExecutor(max_workers=4) as executor:
            all_task = [executor.submit(self.dataConcat, list_a) for list_a in split_array]
            for future in as_completed(all_task):
                part_x, part_y = future.result()
                for i in range(self.tuple_x):
                    # dataConcat returns (list_x, list_y): input i is part_x[i].
                    list_x[i].extend(part_x[i])
                list_y.extend(part_y)
        return list_x, list_y
class logisticsAllScaler(prepareBase):
    """
    prepareBase variant that fits one MinMaxScaler over the data of all
    stocks instead of one per stock. The scaled data is computed once by
    allScaler and cached on the instance.
    """

    def __init__(self):
        super(logisticsAllScaler, self).__init__()
        self.SEQUENCE_LEN = 50
        self.list_x = None  # symbol -> scaled x rows, filled lazily by allScaler

    def getDataWithIndex(self, stock_arr, only_untrain=False):
        """
        Load the (x, y) rows of several stocks.

        :param stock_arr: rows of (symbol, name)
        :param only_untrain: also compute the first untrained window per stock
        :return: (list of x, list of y, row count per stock,
                  symbol -> first untrained window index (0 when not requested))
        """
        list_x, list_y, index_list = [], [], []
        trian_map = {}
        for item in stock_arr:
            print(item[1])
            df = stock_sql.getDailyFrame(item[0])
            index = parseUntrainedData(item[0], df) if only_untrain else 0
            x, y = self.dataProcess(df)
            trian_map[item[0]] = index
            list_x.append(x)
            list_y.append(y)
            index_list.append(len(x))
        return list_x, list_y, index_list, trian_map

    def getDataWithIndexMultiple(self, stock_arr, only_untrain=False):
        """Parallel getDataWithIndex over 4 worker processes; same return value."""
        list_x, list_y, index_list = [], [], []
        trian_map = {}
        a_split = np.array_split(stock_arr, 4, axis=0)
        with ProcessPoolExecutor(max_workers=4) as executor:
            # map() yields results in submission order, keeping list_x and
            # index_list aligned with the insertion order of trian_map.
            results = executor.map(self.getDataWithIndex, a_split,
                                   [only_untrain] * len(a_split))
            for list_xt, list_yt, index_listt, trian_mapt in results:
                list_x.extend(list_xt)
                list_y.extend(list_yt)
                index_list.extend(index_listt)
                trian_map.update(trian_mapt)
        return list_x, list_y, index_list, trian_map

    def allScaler(self):
        """
        Scale the rows of all stocks with shared scalers (computed once).

        :return: (symbol -> scaled x, symbol -> scaled y, scaler_x, scaler_y,
                  symbol -> first untrained window index)
        """
        if self.list_x is not None:
            return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map
        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        arr = np.array(df)
        data_x, data_y, list_index, self.train_map = self.getDataWithIndexMultiple(arr, self.only_untrain)
        data_x = np.concatenate(data_x, axis=0)
        data_y = np.concatenate(data_y, axis=0)
        self.scaler_x, self.scaler_y = MinMaxScaler(), MinMaxScaler()
        data_x = self.scaler_x.fit_transform(data_x)
        data_y = self.scaler_y.fit_transform(data_y)
        # Cut the scaled matrices back into per-stock pieces:
        # stock i owns rows [start, start + list_index[i]).
        self.list_x, self.list_y = {}, {}
        start = 0
        for symbol, count in zip(self.train_map, list_index):
            self.list_x[symbol] = data_x[start:start + count]
            self.list_y[symbol] = data_y[start:start + count]
            start += count
        return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """
        Build the windows of one stock from the globally scaled data.

        :return: (x, y, scaler_x, scaler_y), or None when there are no
                 (untrained) windows
        """
        self.allScaler()
        symbol = stockinfo.symbol
        x, y = self.list_x[symbol], self.list_y[symbol]
        train_x, _ = self.dataSequence(x, nor=False)
        train_y, _ = self.dataSequence_y(y, nor=False)
        if len(train_x) == 0:
            return None  # too few rows, as in prepareBase.getTrainData
        if self.only_untrain:
            # allScaler stores the map as ``train_map``; the old
            # ``self.trian_map`` raised AttributeError.
            index = self.train_map[symbol]
            if index >= len(train_x):
                return None  # every window has been trained on already
            train_x, train_y = train_x[index:], train_y[index:]
        return train_x, train_y, self.scaler_x, self.scaler_y
class prepareLogistics(prepareBase):
    """
    Regression samples: 50-row windows of daily features, target = the next
    day's close (the prepareBase behaviour).
    """

    def __init__(self):
        super(prepareLogistics, self).__init__()
        self.SEQUENCE_LEN = 50
        # No time.sleep(10) here: it stalled every instantiation for nothing.
class prepareClassify(prepareBase):
    """
    Classification samples: the label is 1 when the next day's close is
    higher than the current day's close, otherwise 0.
    """

    def __init__(self):
        super(prepareClassify, self).__init__()
        # No time.sleep(10) here: it stalled every instantiation for nothing.

    def dataProcess(self, df: pd.DataFrame) -> (np.array, np.array):
        """
        :param df: daily bars sorted by trade_date (as from stock_sql.getDailyFrame)
        :return: (x without its last row, labels of shape (n-1, 1)); in verify
                 mode each label is preceded by a readable
                 "date->next_date :next_close-close" string
        """
        column_x, column_y = columnSplit(self.verify)
        x, y_t = np.array(df[column_x]), np.array(df[column_y])
        # Row i is labelled by whether the close rises from row i to row i+1;
        # the last x row has no next day and is dropped below.
        y_value = (np.diff(y_t[:, 0]) > 0).astype(np.int32).reshape(-1, 1)
        if self.verify:
            y_date = [str(y_t[i][1]) + "->" + str(y_t[i + 1][1]) + " :" + str(y_t[i + 1][0]) + "-" + str(y_t[i][0])
                      for i in range(len(y_t) - 1)]
            y_date = np.array(y_date).reshape(-1, 1)
            y_value = np.concatenate((y_date, y_value), axis=1)
        return x[:len(x) - 1], y_value

    def dataSequence_y(self, y: pd.DataFrame, nor: bool = False) -> np.array:
        """
        Pick the label aligned with the last row of each window (no scaling).

        :return: (labels, None)
        """
        data_all = np.array(y)
        data = []
        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):
            data.append(data_all[i + self.SEQUENCE_LEN - 1])
        return np.array(data), None
class prepareLogistics_Ex(prepareBase):
    """
    Regression samples with two extra inputs per window: the stock's one-hot
    area and industry vectors (from get_info.stock_basic.getExVec).
    """

    def __init__(self):
        super(prepareLogistics_Ex, self).__init__()
        self.SEQUENCE_LEN = 50
        self.tuple_x = 3  # daily window + area vector + industry vector
        stock_base = get_info.stock_basic()
        self.area_vec, self.indu_vec, self.area_shape, self.indu_shape = stock_base.getExVec()
        # No time.sleep(10) here: it stalled every instantiation for nothing.

    def getInputShape(self):
        """Return the input shapes: (daily window, area vector, industry vector)."""
        daily_shape = super().getInputShape()
        return daily_shape, (self.area_shape[1],), (self.indu_shape[1],)

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """
        :return: ((windows, area vectors, industry vectors), y, scaler_x,
                 scaler_y), or None when the stock yields no windows
        """
        data = super().getTrainData(stockinfo)
        if data is None:
            return None
        train_x, train_y, self.scaler_x, self.scaler_y = data
        sample_len = train_x.shape[0]
        # Repeat the stock's static vectors once per window.
        ex_area = np.tile(self.area_vec[stockinfo.symbol], (sample_len, 1))
        ex_indu = np.tile(self.indu_vec[stockinfo.symbol], (sample_len, 1))
        return (train_x, ex_area, ex_indu), train_y, self.scaler_x, self.scaler_y
if __name__ == "__main__":
    # Other manual checks:
    #   getTrainData(1)
    #   allScaler()
    #   for data in dataGenerator(0, True): print("data")
    prepareClassify().test()
trainning.py
import sys
# print(sys.executable)
import numpy as np
import matplotlib.pyplot as plt
import gc
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
import datetime
import tensorflow.keras as keras
import prepare
import stock_sql
from abc import ABCMeta, abstractmethod
import time
import pickle
# gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# cpus = tf.config.experimental.list_physical_devices(device_type='CPU')
# print(gpus)
# print(cpus)
#
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# tf.config.experimental.set_visible_devices(devices=gpus[0], device_type='GPU')
# tf.config.experimental.set_visible_devices(devices=cpus[0], device_type='CPU')
# tf.config.experimental.set_virtual_device_configuration(
# gpus[1],
# [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1536)]
# )
import os

weight_file = "model/weights/my_weight"
weight_file_1 = "model/weights/my_weight_1"
index_file = "index.txt"  # stores the index saved by saveTrainIndex
# TensorBoard log directory, one sub-directory named after the start time.
# os.path.join keeps it portable (the old "log\\model_train\\" was Windows-only).
log_dir = os.path.join("log", "model_train", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
model = None
def getTrainIndex():
    """Return the integer stored on the first line of index_file."""
    with open(index_file, "r") as handle:
        first_line = handle.readline()
    return int(first_line)
def saveTrainIndex(index):
    """Overwrite index_file with the given index."""
    with open(index_file, "w") as handle:
        handle.write("%s" % (index,))
        handle.flush()
def build_model(shape):
    """Build the network with model_1(shape), save its diagram to
    picture/multi_model.png and return it."""
    built = model_1(shape)
    keras.utils.plot_model(built, 'picture/multi_model.png', show_shapes=True)
    return built
def model_2(shape):
    """
    Two 50-unit LSTMs followed by Dense(10, tanh) and Dense(1), compiled with
    Adam, MSE loss and MAE metric. Both LSTMs return sequences, so the model
    outputs one value per timestep.
    """
    net = keras.Sequential([
        keras.Input(shape=shape),
        keras.layers.LSTM(units=50, activation='tanh', return_sequences=True),
        keras.layers.LSTM(units=50, activation='tanh', return_sequences=True),
        keras.layers.Dense(units=10, activation="tanh"),
        keras.layers.Dense(units=1),
    ])
    net.compile(optimizer=keras.optimizers.Adam(), loss="mse", metrics=[keras.metrics.mae])
    net.summary()
    return net
def testing(stockInfo):
    """
    Predict for one stock with a new logisticsTrain instance.

    :param stockInfo: {symbol: XXX, name}
    :return: whatever logisticsTrain.predict returns
    """
    trainer = logisticsTrain()
    return trainer.predict(stockInfo)
def showData(test_y, predict_y):
    """Plot predictions (blue stars) against real values (green triangles) and show the figure."""
    positions = list(range(len(predict_y)))
    plt.plot(positions, predict_y, "b*--")
    plt.plot(positions, test_y, "gv--")
    plt.show()
def template():
    """Demo of the Keras functional API with multiple inputs and outputs.

    Builds a network that predicts a document's priority (sigmoid) and its
    handling department (softmax over ``num_departments``) from the document
    body, the title and a tag vector. A diagram of the network is written to
    ``picture/template_model.png``. Nothing is returned.
    """
    # Hyperparameters
    vocab_size = 2000
    num_tags = 12
    num_departments = 4

    # Inputs: variable-length token-id sequences plus a fixed-size tag vector.
    body_tokens = keras.Input(shape=(None,), name='body')
    title_tokens = keras.Input(shape=(None,), name='title')
    tags = keras.Input(shape=(num_tags,), name='tag')

    # Embedding layers
    body_emb = keras.layers.Embedding(vocab_size, 64)(body_tokens)
    title_emb = keras.layers.Embedding(vocab_size, 64)(title_tokens)

    # Feature extraction
    body_vec = keras.layers.LSTM(32)(body_emb)
    title_vec = keras.layers.LSTM(128)(title_emb)
    merged = keras.layers.concatenate([title_vec, body_vec, tags])

    # Output heads
    priority = keras.layers.Dense(
        1, activation='sigmoid', name='priority')(merged)
    department = keras.layers.Dense(
        num_departments, activation='softmax', name='department')(merged)

    # Assemble the model and save its diagram.
    demo = keras.Model(inputs=[body_tokens, title_tokens, tags],
                       outputs=[priority, department])
    keras.utils.plot_model(
        demo, 'picture/template_model.png', show_shapes=True)
def logisticsMode(shape):
    """Build and compile the stacked-LSTM regression model.

    There are four LSTM layers (500, 500, 200, 200 units). Only the last one
    collapses the sequence (``return_sequences=False``). They are followed by
    Dense(200, tanh), Dense(20, tanh) and a single linear output. Compiled
    with Adam, MSE loss and an MAE metric. The summary is printed.

    :param shape: per-sample input shape passed to ``keras.Input``.
    :return: the compiled ``keras.Sequential`` model.
    """
    lstm_stack = ((500, True), (500, True), (200, True), (200, False))
    layers = [keras.Input(shape=shape)]
    for units, keep_sequence in lstm_stack:
        layers.append(keras.layers.LSTM(
            units=units, activation='tanh', return_sequences=keep_sequence))
    layers.append(keras.layers.Dense(units=200, activation="tanh"))
    layers.append(keras.layers.Dense(units=20, activation="tanh"))
    layers.append(keras.layers.Dense(units=1))

    net = keras.Sequential(layers)
    net.compile(
        optimizer=keras.optimizers.Adam(),
        loss="mse",
        metrics=[keras.metrics.mae],
    )
    net.summary()
    return net
def classifyModel(shape, num_classes=3):
    """Build and compile the 4-layer LSTM classifier.

    Architecture:
      - LSTM layers with 500, 500, 200 and 50 units, each keeping the full
        sequence.
      - Flatten, then Dense(200, relu).
      - A linear Dense layer with ``num_classes`` outputs. These are logits,
        so the loss uses ``from_logits=True``.

    Compiled with Adam, sparse categorical cross-entropy (integer labels in
    ``[0, num_classes)``) and an accuracy metric. The summary is printed.

    :param shape: per-sample input shape. Because of the Flatten -> Dense
        step, the number of timesteps must be fixed (not ``None``).
    :param num_classes: number of output classes (default 3).
    :return: the compiled ``keras.Model``.
    """
    inputs = keras.Input(shape=shape)
    seq = inputs
    for units in (500, 500, 200, 50):
        seq = keras.layers.LSTM(
            units=units, activation='tanh', return_sequences=True)(seq)
    flat = keras.layers.Flatten()(seq)
    hidden = keras.layers.Dense(units=200, activation="relu")(flat)
    logits = keras.layers.Dense(units=num_classes)(hidden)

    model = keras.Model(inputs=inputs, outputs=logits)
    # Use the module's ``keras`` alias (tensorflow.keras), as everywhere else.
    model.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    model.summary()
    return model
def classifyModel_1(shape, num_classes=3):
    """Build and compile the deeper 5-layer LSTM classifier.

    Architecture:
      - LSTM layers with 500, 500, 200, 200 and 100 units, each keeping the
        full sequence.
      - Flatten, then Dense(200, relu) and Dense(100, relu).
      - A linear Dense layer with ``num_classes`` outputs. These are logits,
        so the loss uses ``from_logits=True``.

    Compiled with Adam, sparse categorical cross-entropy (integer labels in
    ``[0, num_classes)``) and an accuracy metric. The summary is printed.

    :param shape: per-sample input shape. Because of the Flatten -> Dense
        step, the number of timesteps must be fixed (not ``None``).
    :param num_classes: number of output classes (default 3).
    :return: the compiled ``keras.Model``.
    """
    inputs = keras.Input(shape=shape)
    seq = inputs
    for units in (500, 500, 200, 200, 100):
        seq = keras.layers.LSTM(
            units=units, activation='tanh', return_sequences=True)(seq)
    flat = keras.layers.Flatten()(seq)
    hidden = keras.layers.Dense(units=200, activation="relu")(flat)
    hidden = keras.layers.Dense(units=100, activation="relu")(hidden)
    logits = keras.layers.Dense(units=num_classes)(hidden)

    model = keras.Model(inputs=inputs, outputs=logits)
    # Use the module's ``keras`` alias (tensorflow.keras), as everywhere else.
    model.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    model.summary()
    return model
def logisticsModel_Ex(daily_shape, area_shape, indu_shape):
    """Build and compile the three-input regression model.

    The daily sequence goes through four relu LSTM layers (500, 200, 200, 200
    units; only the last collapses the sequence). The result is concatenated
    with the raw area and industry inputs. This is followed by Dense(200),
    Dense(100) and Dense(100) relu layers and a single linear output.
    Compiled with Adam, MSE loss and an MAE metric. The summary is printed.

    :param daily_shape: per-sample shape of the daily sequence input.
    :param area_shape: per-sample shape of the area input.
    :param indu_shape: per-sample shape of the industry input.
    :return: the compiled ``keras.Model`` taking ``[daily, area, industry]``.
    """
    daily_in = keras.Input(shape=daily_shape)
    area_in = keras.Input(shape=area_shape)
    indu_in = keras.Input(shape=indu_shape)

    # NOTE(review): relu (rather than tanh) LSTMs cannot use the fused cuDNN
    # kernel on GPU; kept as-is to preserve the trained model's behaviour.
    seq = daily_in
    for units, keep_sequence in ((500, True), (200, True), (200, True), (200, False)):
        seq = keras.layers.LSTM(
            units=units, activation="relu", return_sequences=keep_sequence)(seq)

    merged = keras.layers.concatenate([seq, area_in, indu_in])
    hidden = merged
    for units in (200, 100, 100):
        hidden = keras.layers.Dense(units=units, activation="relu")(hidden)
    output = keras.layers.Dense(units=1)(hidden)

    net = keras.Model(inputs=[daily_in, area_in, indu_in], outputs=output)
    net.compile(
        optimizer=keras.optimizers.Adam(),
        loss="mse",
        metrics=[keras.metrics.mae],
    )
    net.summary()
    return net
class trainBase(metaclass=ABCMeta):
    """Shared training / prediction driver for the concrete trainer classes.

    Subclasses are expected to set ``file_path``, ``batch_size`` and
    ``epochs``, and to implement :meth:`buildModel`, which must populate
    ``self.dataPre`` and ``self.model``. ``ABCMeta`` makes ``@abstractmethod``
    effective: a subclass that forgets ``buildModel`` cannot be instantiated.
    """

    def __init__(self):
        self.dataPre: "prepare.prepareBase" = None
        self.model: "keras.Model" = None

    @abstractmethod
    def buildModel(self):
        """Create ``self.dataPre`` and ``self.model`` (and load saved weights)."""

    @staticmethod
    def _input_shapes(train_x):
        """Return the shape of every model input contained in *train_x*.

        *train_x* is either a single array (single-input model) or a
        list/tuple of arrays (multi-input model).
        """
        inputs = train_x if isinstance(train_x, (list, tuple)) else [train_x]
        return [arr.shape for arr in inputs]

    def trainAll(self):
        """Fit the model on every batch yielded by ``dataPre.dataGenerator(0)``.

        Each batch is a ``(train_x, train_y, stock_list)`` tuple. After each
        fit, the weights are saved to ``self.file_path`` and ``stock_list`` is
        passed to ``stock_sql.updateTrianList``. That call presumably marks
        these stocks as trained (confirm in stock_sql).
        """
        import gc

        for train_x, train_y, stock_list in self.dataPre.dataGenerator(0):
            for input_shape in self._input_shapes(train_x):
                print(input_shape)
            print(train_y.shape)
            self.model.fit(train_x, train_y, batch_size=self.batch_size,
                           epochs=self.epochs, validation_split=0.05)
            self.model.save_weights(self.file_path)
            stock_sql.updateTrianList(stock_list)
            # Drop the batch before the generator builds the next one.
            del train_x, train_y, stock_list
            gc.collect()
        gc.collect()

    def predict(self, stockinfo: "prepare.stockInfo") -> "tuple[np.ndarray, np.ndarray]":
        """Return ``(real_y, predict_y)`` for *stockinfo*.

        Steps:
          1. If ``stock_sql.getRecordData(stockinfo.symbol)`` returns a truthy
             value, its first two items are returned as-is.
          2. Otherwise, test data comes from ``dataPre.getTestData``. If that
             returns ``None``, two empty arrays are returned.
          3. Otherwise, the model predicts. When a ``scaler_y`` is supplied,
             both series are inverse-transformed.
          4. Both series are rounded to 2 decimals, stored via
             ``stock_sql.SaveRecordData`` and returned.
        """
        # NOTE(review): relies on getRecordData returning None or a sequence;
        # a bare numpy array here would make ``if cached`` ambiguous.
        cached = stock_sql.getRecordData(stockinfo.symbol)
        if cached:
            return cached[0], cached[1]

        data = self.dataPre.getTestData(stockinfo)
        if data is None:
            return np.array([]), np.array([])

        test_x, real_y, _scaler_x, scaler_y = data[:4]
        predict_y = self.model.predict(test_x)
        if scaler_y is not None:
            real_y = scaler_y.inverse_transform(real_y)
            predict_y = scaler_y.inverse_transform(predict_y.astype("float64"))
        real_y = np.around(real_y, decimals=2)
        predict_y = np.around(predict_y, decimals=2)
        stock_sql.SaveRecordData(stockinfo, real_y, predict_y)
        return real_y, predict_y
class logisticsTrain(trainBase):
    """Regression trainer.

    Uses the model built by ``logisticsMode`` on data from
    ``prepare.logisticsAllScaler``.
    """

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/my_weight"
        self.batch_size = 1024
        self.epochs = 20
        self.buildModel()

    def buildModel(self):
        """Create the data pipeline and model, then try to restore saved weights."""
        print("parepareLogistics")
        self.dataPre = prepare.logisticsAllScaler()
        self.model = logisticsMode(self.dataPre.getInputShape())
        print("load_weights from " + self.file_path)
        try:
            self.model.load_weights(self.file_path)
        except Exception as err:
            # Best effort: without a usable checkpoint the model keeps its
            # freshly initialised weights. Report why instead of hiding it.
            print("load_weight failed", err)
        else:
            print(" success ")
        # NOTE(review): unconditional pause, presumably so the model summary
        # can be read before training output scrolls it away; confirm.
        time.sleep(10)
class classifyTrain(trainBase):
    """Classification trainer.

    Uses the model built by ``classifyModel`` on data from
    ``prepare.prepareClassify``.
    """

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/classify_weight"
        self.batch_size = 1024
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data pipeline and model, then try to restore saved weights."""
        self.dataPre = prepare.prepareClassify()
        print("prepareClassify")
        self.model = classifyModel(self.dataPre.getInputShape())
        print("load_weights from " + self.file_path)
        try:
            self.model.load_weights(self.file_path)
        except Exception as err:
            # Best effort: without a usable checkpoint the model keeps its
            # freshly initialised weights. Report why instead of hiding it.
            print("load_weight failed", err)
        else:
            print(" success ")
        # NOTE(review): unconditional pause, presumably so the model summary
        # can be read before training output scrolls it away; confirm.
        time.sleep(10)
class classifyTrain_1(trainBase):
    """Classification trainer.

    Uses the deeper model built by ``classifyModel_1`` on data from
    ``prepare.prepareClassify``.
    """

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/classify_weight_1"
        self.batch_size = 1024
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data pipeline and model, then try to restore saved weights."""
        self.dataPre = prepare.prepareClassify()
        print("prepareClassify")
        self.model = classifyModel_1(self.dataPre.getInputShape())
        print("load_weights from " + self.file_path)
        try:
            self.model.load_weights(self.file_path)
        except Exception as err:
            # Best effort: without a usable checkpoint the model keeps its
            # freshly initialised weights. Report why instead of hiding it.
            print("load_weight failed", err)
        else:
            print(" success ")
        # NOTE(review): unconditional pause, presumably so the model summary
        # can be read before training output scrolls it away; confirm.
        time.sleep(10)
class logisticsTrain_Ex(trainBase):
    """Three-input regression trainer (daily sequence, area and industry inputs).

    Uses the model built by ``logisticsModel_Ex`` on data from
    ``prepare.prepareLogistics_Ex``.
    """

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/logisticsEx_weight"
        self.batch_size = 512
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data pipeline and model, then try to restore saved weights."""
        self.dataPre = prepare.prepareLogistics_Ex()
        print("prepareClassifyEx")
        daily_shape, area_shape, indu_shape = self.dataPre.getInputShape()
        self.model = logisticsModel_Ex(daily_shape, area_shape, indu_shape)
        print("load_weights from " + self.file_path)
        try:
            self.model.load_weights(self.file_path)
        except Exception as err:
            # Best effort: without a usable checkpoint the model keeps its
            # freshly initialised weights. Report why instead of hiding it.
            print("load_weight failed", err)
        else:
            print(" success ")
        # NOTE(review): unconditional pause, presumably so the model summary
        # can be read before training output scrolls it away; confirm.
        time.sleep(10)
if __name__ == '__main__':
    # Script entry point: train the three-input regression model.
    # stock_sql.initTrainDate()
    trainer = logisticsTrain_Ex()
    # NOTE(review): presumably makes dataGenerator include stocks that were
    # already trained, not only untrained ones; confirm in prepare.
    trainer.dataPre.only_untrain = False
    trainer.trainAll()
