股票模型预测

简介:

处理过程大体如下:
1、使用tushare获取stock信息
2、对数据进行处理,做好train_x和train_y的对应关系
3、训练和预测:网络学习x与y之间的映射关系,网络结构要与数据匹配

1、get_info.py 用于从tushare获取stock信息,目前只有日线数据,后面会增加(tushare积分不够,有些数据获取不到,欢迎大家注册并帮忙贡献积分)
2、stock_sql.py 用于将部分信息记录到数据库,方便查询检索。数据库使用sqlite3。
3、prepare.py 用于对数据进行处理,生成train_x,train_y的对应关系,满足网络训练需要。NN可以学习映射规律,进行预测。
4、trainning.py 定义网络模型并进行训练。
5、evaluate.py 对模型进行简单的评估。
6、server.py flask做的后台用于数据展示。

代码实现:

get_info.py

from concurrent.futures import ProcessPoolExecutor, as_completed

from pandas import DataFrame

import stock_sql
import pandas as pd
import numpy as np
import datetime
import time
import tushare as ts

import jieba
import re
from gensim.test.utils import common_texts, get_tmpfile
from gensim.models import Word2Vec,word2vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

from sklearn.preprocessing import OneHotEncoder

#path = get_tmpfile("word2vec.model")

print(ts.__version__)

# NOTE(review): the tushare token is a credential and should not live in
# source control. It is read from the TUSHARE_TOKEN environment variable; the
# literal below is only a backward-compatible fallback -- rotate the token and
# delete the fallback once every deployment sets the variable.
import os

ts.set_token(os.environ.get("TUSHARE_TOKEN")
             or "d1af48f518c17415b1b98b2ce84ab7b1a0025adfdde78e22513b31ec")
pro = ts.pro_api()

# tushare ``stock_basic`` field name -> Chinese description (order matters:
# the values are used as CSV headers in the same order).
d = dict(
    ts_code="TS代码",
    symbol="股票代码",
    name="股票名称",
    area="所在地域",
    industry="所属行业",
    fullname="股票全称",
    enname="英文全称",
    market="市场类型(主板/中小板/创业板/科创板)",
    exchange="交易所代码",
    curr_type="交易货币",
    list_status="上市状态:L上市 D退市 P暂停上市",
    list_date="上市日期",
    delist_date="退市日期",
    is_hs="是否沪深港通标的,N否 H沪股通 S深股通",
)

# tushare ``daily`` field name -> Chinese description.
daily = dict(
    ts_code="股票代码",
    trade_date="交易日期",
    open="开盘价",
    high="最高价",
    low="最低价",
    close="收盘价",
    pre_close="昨收价",
    change="涨跌额",
    pct_chg="涨跌幅(未复权)",
    vol="成交量(手)",
    amount="成交额(千元)",
)

def dateRange(beginDate, endDate):
    """
    List every calendar day from *beginDate* to *endDate* (both inclusive).

    Dates are ``YYYYMMDD`` strings; the loop bound is a plain string
    comparison, which orders correctly for that fixed-width format.

    :param beginDate: first day, ``YYYYMMDD``
    :param endDate: last day, ``YYYYMMDD``
    :return: list of ``YYYYMMDD`` strings (empty when beginDate > endDate)
    """
    current = datetime.datetime.strptime(beginDate, "%Y%m%d")
    label = beginDate
    result = []
    while label <= endDate:
        result.append(label)
        current += datetime.timedelta(1)
        label = current.strftime("%Y%m%d")
    return result

class stock_basic:
    """
    Basic (static) information of every listed stock.

    Wraps the tushare ``stock_basic`` endpoint, writes CSV snapshots and
    builds one-hot encodings of the area / industry columns.
    """
    # tushare field name -> Chinese column header used in the CSV snapshot
    basic_info = {
        "ts_code": "TS代码",
        "symbol": "股票代码",
        "name": "股票名称",
        "area": "所在地域",
        "industry": "所属行业",
        "fullname": "股票全称",
        "enname": "英文全称",
        "market": "市场类型(主板/中小板/创业板/科创板)",
        "exchange": "交易所代码",
        "curr_type": "交易货币",
        "list_status": "上市状态:L上市 D退市 P暂停上市",
        "list_date": "上市日期",
        "delist_date": "退市日期",
        "is_hs": "是否沪深港通标的,N否 H沪股通 S深股通",
    }

    def __init__(self):
        self.file_path = "data/stock_basic.csv"
        self.record_path = "data/data.csv"

    def getBasicInfo(self):
        """
        Download all listed stocks (list_status='L') from tushare and save
        them to ``self.file_path`` with the Chinese headers of ``basic_info``.
        """
        fields = list(self.basic_info.keys())
        headers = list(self.basic_info.values())
        data: DataFrame = pro.query('stock_basic', exchange='',
                                    list_status='L', fields=fields)
        data.columns = headers
        data.to_csv(self.file_path, header=True, encoding='utf-8', index=False)

    def recordDateInfo(self):
        """
        Write a per-stock update record to ``self.record_path``.

        Keeps TS code, name, list date and delist date from the basic
        snapshot and adds an "更新时间" (update time) column initialised to the
        stock's list date.
        """
        data = pd.read_csv(self.file_path, dtype=object)
        # Columns to keep; every other column of the snapshot is dropped.
        keep = ["TS代码", "股票名称", "上市日期", "退市日期"]
        header = [val for val in d.values() if val not in keep]
        print(header)
        data_1 = data.drop(header, axis=1)
        # The old code inserted today's date and then overwrote it row by row
        # through ``data_1.values``; that array can be a copy, so the list
        # dates were silently lost. Assign the column directly instead.
        data_1.insert(loc=len(data_1.columns), column="更新时间",
                      value=data_1["上市日期"].values, allow_duplicates=False)
        data_1.to_csv(self.record_path, header=True, encoding='utf-8', index=False)

    @staticmethod
    def _dense_one_hot(column: DataFrame):
        """One-hot encode *column* into a dense ndarray on any scikit-learn version."""
        try:
            encoder = OneHotEncoder(sparse_output=False)  # scikit-learn >= 1.2
        except TypeError:
            encoder = OneHotEncoder(sparse=False)  # older scikit-learn
        return encoder.fit_transform(column)

    def getExVec(self):
        """
        Build one-hot vectors of each stock's area and industry.

        :return: (symbol -> area vector, symbol -> industry vector,
                  area matrix shape, industry matrix shape)
        """
        df = stock_sql.getStockFrame(headers=["symbol", "name", "area", "industry"])
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        # Missing area/industry become a single " " category.
        df.fillna(" ", inplace=True)
        area_vec = self._dense_one_hot(df[["area"]])
        indu_vec = self._dense_one_hot(df[["industry"]])
        symbols = df["symbol"].tolist()
        data_area = dict(zip(symbols, area_vec))
        data_indu = dict(zip(symbols, indu_vec))
        return data_area, data_indu, area_vec.shape, indu_vec.shape

class stock_daily(object):
    """Download daily (日线) bars from tushare and merge them into per-stock CSVs."""

    def __init__(self):
        pass

    def updateDailyInfo(self):
        """
        Update the daily bars of every stock recorded in the stock_date table.
        :return: None
        """
        df = stock_sql.getTradeDateList()
        # rows: symbol, name, ts_code, trade_date (last downloaded date)
        stock_list = np.array(df)
        self.updateDailyList(stock_list)

    def createDateTable(self):
        """
        Create the stock_date table (per-stock update/train dates) and
        initialise it from the stock_info table.
        :return: None
        """
        stock_sql.createDateTable()
        stock_sql.initDateTable()

    def getDailyInfo(self, symbol, ts_code, start_time: str, end_time: str) -> pd.DataFrame:
        """
        Download the daily bars of one stock between two dates, page by page.

        Each request covers [start_time, end_time]; the next page ends at the
        oldest trade_date of the previous one. Pages overlap by one row; the
        duplicate is removed later by dailyInfoConcat.

        :param symbol: stock symbol (unused, kept for the call signature)
        :param ts_code: tushare code, e.g. "000001.SZ"
        :param start_time: first date, YYYYMMDD
        :param end_time: last date, YYYYMMDD
        :return: concatenated pages, or None when nothing was downloaded
        """
        df_list = []
        while start_time != end_time:
            df_t = pro.daily(ts_code=ts_code, start_date=start_time, end_date=end_time)
            if df_t is None or df_t.empty:
                break  # nothing (more) in range
            df_list.append(df_t)
            oldest = df_t.tail(1)["trade_date"].values[0]
            if str(oldest) == str(end_time):
                # No older row was returned: the range is exhausted. Without
                # this check the loop never ends when start_time is not itself
                # a trading day (weekend, holiday, suspension).
                break
            end_time = oldest

        if df_list:
            return pd.concat(df_list, join="inner")
        return None

    def dailyInfoConcat(self, symbol: str, data: pd.DataFrame):
        """
        Merge newly downloaded rows into the stock's CSV file and save it.

        :param symbol: stock symbol
        :param data: newly downloaded daily rows
        :return: latest update date (trade_date of the first merged row)
        """
        path = stock_sql.getFilePath(symbol)
        df = stock_sql.getStockData(path)
        data.trade_date = data.trade_date.astype("int64")
        # New rows first so that drop_duplicates keeps the fresh copy.
        data_n = pd.concat([data, df], join='inner', ignore_index=True)
        data_n.drop_duplicates(subset=["trade_date"], keep='first', inplace=True)
        data_n.to_csv(path, header=True, encoding='utf-8', index=False)
        return data_n.head(1)["trade_date"].values[0]

    def recordLastDate(self, symbol, data):
        """
        Record the last daily-info update date of a stock.
        :param symbol: stock symbol
        :param data: last update date
        :return: None
        """
        stock_sql.updateTradeDate(symbol, data)

    def updateDailyList(self, stockList, end_date=None):
        """
        Update the daily bars of every stock in *stockList*.

        :param stockList: rows of (symbol, name, ts_code, last trade_date)
        :param end_date: last date to download (YYYYMMDD); defaults to today.
                         (Previously a stale hard-coded "20200312" was used.)
        :return: None
        """
        if end_date is None:
            end_date = datetime.date.today().strftime("%Y%m%d")
        for item in stockList:
            print(item)
            data = self.getDailyInfo(item[0], item[2], item[3], end_date)
            if data is not None:
                last_d = self.dailyInfoConcat(item[0], data)
                self.recordLastDate(item[0], last_d)

    def updateDailyListMutliple(self, stockList):
        """
        Update the daily bars of *stockList* in parallel worker processes.

        :param stockList: rows of (symbol, name, ts_code, last trade_date)
        :return: None
        """
        stock_split = stock_sql.arr_split(stockList, 500)
        # ``with`` shuts the pool down and joins its workers.
        with ProcessPoolExecutor(max_workers=8) as executor:
            all_task = [executor.submit(self.updateDailyList, list_s) for list_s in stock_split]
            for future in as_completed(all_task):
                future.result()  # re-raise any exception from a worker
            res = future.result()

class stock_company(object):
    """
    Company profile data (province, city, main business, business scope) plus
    NLP experiments on it: jieba word cutting and word2vec / doc2vec models.
    """

    def __init__(self):
        self.file_path = "data/stock_company.csv"
        # tushare field name -> Chinese description
        self.company_info = {
            "ts_code": "股票代码",
            "province": "所在省份",
            "city": "所在城市",
            "main_business": "主营业务",
            "business_scope": "经营范围"
        }

    def getCompanyInfo(self):
        """Download the company profiles from tushare and save them to ``self.file_path``."""
        df: DataFrame = pro.stock_company(fields=list(self.company_info.keys()))
        df.to_csv(self.file_path, header=True, encoding='utf-8', index=False)
        print(df)

    def getCompanyData(self):
        """Load the saved profiles sorted by ts_code, without the ts_code column."""
        drops = ["ts_code"]
        columns = [x for x in self.company_info.keys() if x not in drops]
        df: DataFrame = pd.read_csv(self.file_path, index_col=False, na_values="")
        df.sort_values(by=["ts_code"], ascending=True, ignore_index=True, inplace=True)
        return df[columns]

    def removePunctuation(self, text):
        """
        Replace every run of punctuation in *text* with one space.

        :return: the result stripped and lower-cased
        """
        # ASCII marks plus their full-width CJK counterparts
        # (,。():;、!?【】“”).
        marks = ("!,;:?()[]\n'"
                 "\uff0c\u3002\uff08\uff09\uff1a\uff1b\u3001\uff01\uff1f\u3010\u3011\u201c\u201d")
        # re.escape is essential: unescaped "[" and "]" ended the character
        # class early, so the old pattern (almost) never matched.
        text = re.sub("[{}]+".format(re.escape(marks)), " ", text)
        return text.strip().lower()

    def word_cut_test(self):
        """Print the jieba segmentation modes for a few sample sentences."""
        sentences = ["吸收公共存款",
                     "2015年我毕业于西安科技大学",
                     "2015年我毕业于西安电子科技大学",
                     "2015年我毕业于西安建筑科技大学",
                     "2015年我毕业于西安交通大学",
                     "2015年我毕业于北京大学"]

        for sentence in sentences:
            words = jieba.cut(sentence, cut_all=True)  # full mode
            print("全模式:  %s" % " ".join(words))

            words = jieba.cut(sentence, use_paddle=True)
            print("新词模式:  %s" % " ".join(words))

            words = jieba.cut(sentence)  # default (accurate) mode
            print("精确模式:  %s" % " ".join(words))

            words = jieba.cut_for_search(sentence)  # search-engine mode
            print("搜索模式:  %s" % " ".join(words))

    def word_cut(self):
        """
        Segment every company's business scope with jieba and write one
        space-separated line of words per company to nlp/cut_words.txt.

        Words of length 1 and stop words are dropped. A company with a missing
        business scope gets an empty line, so line numbers stay aligned with
        the company rows.
        """
        jieba.enable_paddle()
        scopes = self.getCompanyData()["business_scope"].tolist()
        stopwords = set(self.stopwordslist("nlp/stop_words.txt"))  # O(1) lookups
        with open("nlp/cut_words.txt", "w+", encoding='utf-8') as fw:
            for scope in scopes:
                if not isinstance(scope, str):
                    scope = ""  # a missing business scope is read as NaN
                text = self.removePunctuation(scope)
                item = jieba.cut(text, use_paddle=True)  # paddle mode
                santi_words = [x for x in item if len(x) > 1 and x not in stopwords]
                # PathLineSentences splits lines on whitespace, so the words
                # must be space-separated (writelines() glued them together).
                fw.write(" ".join(santi_words))
                fw.write("\n")
                print(santi_words)

    def stopwordslist(self, filepath):
        """Return the stripped lines of the UTF-8 stop-word file at *filepath*."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    @staticmethod
    def _train_word2vec(sentences, **kwargs):
        """Train a 20-dimensional Word2Vec on gensim 4 (vector_size) or 3.x (size)."""
        try:
            return Word2Vec(sentences, vector_size=20, **kwargs)
        except TypeError:
            return Word2Vec(sentences, size=20, **kwargs)

    def word2vec(self):
        """Train, save and reload a word2vec model on nlp/cut_words.txt; print a sample query."""
        sentences = word2vec.PathLineSentences("nlp/cut_words.txt")
        model = self._train_word2vec(sentences, window=5, min_count=1, workers=4)
        model.save("nlp/word2vec.model")
        model = Word2Vec.load("nlp/word2vec.model")
        vector = model.wv['新材料']
        # similar_by_vector lives on the KeyedVectors (removed from the model in gensim 4)
        a = model.wv.similar_by_vector(vector)
        print(a)
        print(vector)

    def doc2vec(self):
        """Train, save and reload a doc2vec model (one document per line) and run a sample query."""
        sentences = word2vec.PathLineSentences("nlp/cut_words.txt")
        documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(sentences)]
        model = Doc2Vec(documents, vector_size=20, window=2, min_count=1, workers=4)
        model.save("nlp/doc2vec.model")
        model = Doc2Vec.load("nlp/doc2vec.model")
        # NOTE(review): the phrase is passed as a single unsegmented token; it
        # is probably not in the vocabulary -- consider jieba-cutting it first.
        vector = model.infer_vector(["电器开关零部件及附件制造"])
        print(model.wv.similar_by_vector(vector))

if __name__ == "__main__":
    # Other development entry points: stock_company().doc2vec()
    stock_basic().getExVec()

stock_sql.py

from datetime import datetime
import pandas as pd
import numpy as np
import sqlite3
import os
import pickle

# tushare ``stock_basic`` field name -> Chinese description; the keys double
# as the column names of the stock_info table.
d = {"ts_code": "TS代码",
     "symbol": "股票代码",
     "name": "股票名称",
     "area": "所在地域",
     "industry": "所属行业",
     "fullname": "股票全称",
     "enname": "英文全称",
     "market": "市场类型(主板/中小板/创业板/科创板)",
     "exchange": "交易所代码",
     "curr_type": "交易货币",
     "list_status": "上市状态:L上市 D退市 P暂停上市",
     "list_date": "上市日期",
     "delist_date": "退市日期",
     "is_hs": "是否沪深港通标的,N否 H沪股通 S深股通",
     }

# SQLite database file. The former literal "data\my_stock.db" relied on the
# invalid escape "\m" (a SyntaxWarning in Python 3.12) and, outside Windows,
# created a file literally named "data\my_stock.db" in the working directory.
DataBase = os.path.join("data", "my_stock.db")
RECORDTABLE = ''

class my_sql:
    """
    Context manager that yields a cursor on a fresh SQLite connection.

    On a clean exit the transaction is committed. If the body raises, the
    transaction is rolled back instead of committing a half-applied batch.
    By this project's design the exception is then suppressed, but it is now
    logged rather than silently dropped; pass ``suppress=False`` to let it
    propagate. The connection is always closed.
    """

    def __init__(self, baseName, suppress=True):
        self.conn = sqlite3.connect(baseName)
        self.suppress = suppress

    def __enter__(self):
        return self.conn.cursor()

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if exc_type is None:
                self.conn.commit()
            else:
                self.conn.rollback()
        finally:
            self.conn.close()
        if exc_type is None or not self.suppress:
            return False
        import logging
        logging.getLogger(__name__).error(
            "SQLite operation failed (suppressed by my_sql)",
            exc_info=(exc_type, exc_value, traceback))
        return True  # swallow the exception (historical behaviour)

def CreateStockTabel():
    """Create the stock_info table with one CHAR(50) column per key of ``d``."""
    column_defs = ", ".join("{} CHAR(50)".format(key) for key in d)
    Sql_str = "CREATE TABLE stock_info (" + column_defs + ") ;"
    print(Sql_str)
    with my_sql(DataBase) as cursor:
        cursor.execute(Sql_str)

def createRecordTable():
    """Create (if missing) the stock_record table holding pickled real/predicted data per day."""
    ddl = ('CREATE TABLE  IF NOT EXISTS stock_record ('
           'symbol TEXT PRIMARY KEY, name TEXT, time TEXT, real_data BLOB, predict_data BLOB )')
    with my_sql(DataBase) as cursor:
        cursor.execute(ddl)

def createDateTable():
    """
    Create (if missing) the stock_date table, which records per stock the
    date the daily data was updated to (trade_date) and the date training
    reached (train_date).
    :return: None
    """
    ddl = ('CREATE TABLE  IF NOT EXISTS stock_date ('
           'symbol TEXT PRIMARY KEY, name TEXT, ts_code TEXT, '
           'list_date TEXT, trade_date TEXT, train_date TEXT)')
    with my_sql(DataBase) as cursor:
        cursor.execute(ddl)

def saveStockData():
    """
    Load data/stock_basic.csv and store it as the stock_info table
    (replacing any existing table), using the keys of ``d`` as column names.
    :return: None
    """
    df = pd.read_csv("data/stock_basic.csv", dtype=object)
    df.columns = list(d.keys())
    conn = sqlite3.connect(DataBase)
    try:
        df.to_sql(name="stock_info", con=conn, index=False, if_exists="replace")
    finally:
        conn.close()  # also on failure

def initDateTable():
    """
    Fill stock_date with one row per stock from stock_info, setting both the
    update date (trade_date) and the train date to the stock's list date.
    :return: None
    """
    sqlStr = "SELECT symbol, name, ts_code,list_date FROM stock_info "
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql(sqlStr, con=conn)
    finally:
        conn.close()
    # stock_date has no "info_date" column (see createDateTable): the old
    # statement failed on the first row and my_sql swallowed the error, so
    # the table was never initialised.
    insert_sql = ("REPLACE INTO stock_date (symbol, name, ts_code, list_date, trade_date, train_date) "
                  "VALUES (?,?,?,?,?,?)")
    rows = [(symbol, name, ts_code, list_date, list_date, list_date)
            for symbol, name, ts_code, list_date in
            df[["symbol", "name", "ts_code", "list_date"]].itertuples(index=False, name=None)]
    with my_sql(DataBase) as c:
        c.executemany(insert_sql, rows)

def saveStockPath():
    """
    Store the location of each stock's daily CSV ("data/info/<symbol>.csv")
    in the stock_path table (name, symbol, path), replacing the old table.
    :return: None
    """
    sqlStr = "SElECT name, symbol, ts_code FROM stock_info "
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql(sqlStr, con=conn)
        paths = ["".join(["data/info/", symbol, ".csv"]) for symbol in df["symbol"]]
        df.drop(columns=["ts_code"], inplace=True)
        df.insert(loc=2, column="path", value=paths)
        df.to_sql(name="stock_path", con=conn, index=False, if_exists='replace')
    finally:
        conn.close()
    print(df)

def initTrainDate():
    """
    Reset the train date of every stock in stock_date to its list date.
    :return: None
    """
    with my_sql(DataBase) as cursor:
        cursor.execute("UPDATE stock_date SET train_date = list_date")
    print("initTrainDate")

def updateTradeDate(symbol, date):
    """
    Record the date up to which a stock's daily data has been downloaded.
    :param symbol: stock symbol
    :param date: trade date (YYYYMMDD; str or integer)
    :return: None
    """
    sqlStr = "UPDATE stock_date SET trade_date = ? WHERE symbol = ?"
    with my_sql(DataBase) as c:
        # Bound parameters instead of str.format (no quoting/injection issues).
        # str(): callers pass numpy integers, which sqlite3 cannot bind, and
        # the column has always stored the decimal text form.
        c.execute(sqlStr, (str(date), str(symbol)))

def updateTrianList(stockList):
    """
    Mark the stocks in *stockList* as trained up to their current trade date.
    :param stockList: rows whose first element is the stock symbol
    :return: None
    """
    sqlStr = "UPDATE stock_date SET train_date = trade_date WHERE symbol = ?"
    with my_sql(DataBase) as c:
        c.executemany(sqlStr, [(str(item[0]),) for item in stockList])

def getPathList(symbolList):
    """
    Look up the daily-CSV path of every stock in *symbolList*.
    :param symbolList: rows whose first element is the symbol, e.g. (symbol, name)
    :return: list of (path, symbol); symbols without a recorded path are skipped
    """
    pathList = []
    sqlStr = "SELECT path FROM stock_path WHERE symbol = ?"
    with my_sql(DataBase) as c:
        for item in symbolList:
            row = c.execute(sqlStr, (str(item[0]),)).fetchone()
            if row is None:
                # Previously fetchone()[0] raised here, my_sql swallowed it and
                # the rest of the list was silently dropped.
                continue
            pathList.append((row[0], item[0]))
    return pathList

def getFilePath(symbol):
    """
    Return the daily-CSV path recorded for *symbol* in stock_path.
    :param symbol: stock symbol
    :return: path string, or None when the symbol has no entry
    """
    sqlStr = "SELECT path FROM stock_path WHERE symbol = ?"
    with my_sql(DataBase) as c:
        row = c.execute(sqlStr, (symbol,)).fetchone()
    return row[0] if row is not None else None

def getLastTradeList(pathList):
    """
    Read each stock's CSV and report its most recent trading row.
    :param pathList: rows of (path, symbol)
    :return: list of (symbol, last trade_date, ts_code), values as strings
    """
    result = []
    for entry in pathList:
        frame = pd.read_csv(entry[0], dtype=object)
        frame.sort_values(by=["trade_date"], inplace=True, ignore_index=True)
        last_date, ts_code = frame.tail(1)[["trade_date", "ts_code"]].values[0]
        result.append((entry[1], last_date, ts_code))
    return result

def getLastDate(symbol):
    """
    Return the recorded update and train dates of a stock.
    :param symbol: stock symbol
    :return: (trade_date, train_date), or None when the symbol is unknown
    """
    sqlStr = "SELECT symbol, trade_date, train_date FROM stock_date WHERE symbol = ?"
    with my_sql(DataBase) as c:
        data = c.execute(sqlStr, (symbol,)).fetchone()
    if data is None:
        return None
    return data[1], data[2]

def getLastTradeDate(symbol):
    """
    Return the most recent row of a stock's daily CSV.
    :param symbol: stock symbol
    :return: ndarray [trade_date, ts_code] of the latest row (strings)
    """
    frame = pd.read_csv(getFilePath(symbol), dtype=object)
    frame.sort_values(by=["trade_date"], inplace=True, ignore_index=True)
    return frame.tail(1)[["trade_date", "ts_code"]].values[0]

def getTradeDateList():
    """
    Return symbol, name, ts_code and trade_date of every row of stock_date.
    :return: DataFrame
    """
    sqlStr = "SELECT symbol, name, ts_code,trade_date FROM stock_date"
    conn = sqlite3.connect(DataBase)
    try:
        return pd.read_sql(sqlStr, con=conn)
    finally:
        conn.close()

def getStockFrame(headers=None):
    """
    Return the given columns of the stock_info table.
    :param headers: column names; defaults to ["symbol", "name"]. They are
                    interpolated into the SQL, so pass trusted names only.
    :return: DataFrame
    """
    if headers is None:  # avoid a shared mutable default
        headers = ["symbol", "name"]
    sqlStr = "SELECT {} FROM stock_info".format(columnsToSql(headers))
    conn = sqlite3.connect(DataBase)
    try:
        return pd.read_sql(sqlStr, con=conn, columns=headers)
    finally:
        conn.close()

def getStockInfo(item):
    """
    Build a one-row display frame for a stock: every cell becomes
    "<label>: <value>" as produced by parseStockInfo.
    :param item: mapping with a "symbol" key
    :return: one-row DataFrame, or an empty DataFrame when item is empty or
             the symbol does not match exactly one row
    """
    if not item:
        return pd.DataFrame()

    columns = ["symbol", "area", "industry", "market", "list_date", "is_hs"]
    sqlStr = "SELECT {} From stock_info WHERE symbol = ?".format(columnsToSql(columns))
    conn = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql(sqlStr, con=conn, params=(item["symbol"],))
    finally:
        conn.close()
    df.fillna("", inplace=True)

    if df.shape[0] != 1:
        return pd.DataFrame()
    for key in columns:
        name, value = parseStockInfo(key, df.at[0, key])
        print(name, value)
        # Write through the frame: writing into ``df.values`` is not
        # guaranteed to reach the frame (read-only under Copy-on-Write).
        df.at[0, key] = name + ": " + str(value)
    return df

def getDailyFrame(symbol:str)->pd.DataFrame:
    """
    Load a stock's daily data, de-duplicated, NaN -> 0, sorted by trade_date
    ascending with a fresh 0..n-1 index.
    """
    frame = getStockData(getFilePath(symbol))
    frame.fillna(0, inplace=True)
    frame.sort_values(by=["trade_date"], ascending=True, ignore_index=True, inplace=True)
    return frame

def getStockData(path):
    """
    Read a per-stock daily CSV and drop rows with a repeated trade_date
    (the first occurrence is kept).

    Cells equal to 0 are read as NaN (``na_values=0``); getDailyFrame turns
    NaN back into 0. NOTE(review): intent of this round-trip not visible here.
    """
    frame = pd.read_csv(path, index_col=False, na_values=0)
    frame.drop_duplicates(subset=["trade_date"], keep='first', inplace=True)
    return frame

def getRecordData(symbol):
    """
    Return today's cached (real, predicted) data of a stock from stock_record.
    :param symbol: stock symbol
    :return: (real, predict) unpickled objects, or None when absent
    """
    if not symbol:
        return None
    now_date = datetime.now().strftime("%Y%m%d")
    sqlStr = ("SELECT real_data, predict_data, symbol FROM stock_record"
              " WHERE symbol = ? AND time = ?")
    with my_sql(DataBase) as c:
        data = c.execute(sqlStr, (symbol, now_date)).fetchone()
        if data:
            # NOTE(review): pickle.loads can execute arbitrary code from a
            # crafted blob. Acceptable only because SaveRecordData writes
            # these rows into a local database -- never use untrusted DBs.
            return pickle.loads(data[0]), pickle.loads(data[1])
    return None

def SaveRecordData(stockInfo, realData, predictData):
    """
    Store today's real and predicted data of a stock (pickled) in stock_record,
    replacing an existing row with the same primary key.
    :param stockInfo: mapping with a "symbol" key
    """
    real_blob = pickle.dumps(realData)
    pred_blob = pickle.dumps(predictData)
    today = datetime.now().strftime("%Y%m%d")
    insert_sql = "REPLACE INTO stock_record (symbol,name,time,real_data, predict_data) VALUES (?,?,?,?,?) "
    with my_sql(DataBase) as cursor:
        cursor.execute(insert_sql, (stockInfo["symbol"], "", today, real_blob, pred_blob))

def ClearRecordData():
    """Drop the stock_record table; a no-op when it does not exist."""
    with my_sql(DataBase) as c:
        c.execute("DROP TABLE IF EXISTS stock_record")

def parseStockInfo(key, value):
    """
    Map a stock_info column and its raw value to a display (label, value) pair.

    The label defaults to ``d[key]`` (KeyError for unknown keys); a few
    columns get a shorter label and/or a translated code value.
    """
    label = d[key]
    # key -> (replacement label or None, {raw code: display value})
    special = {
        "market": ("市场类型", {}),
        "curr_type": (None, {"CNY": "RMB"}),
        "list_status": ("上市状态", {"L": "上市", "D": "退市", "P": "暂停上市"}),
        "is_hs": ("沪深港通", {"S": "深港通", "H": "沪港通", "N": "否"}),
    }
    if key not in special:
        return label, value
    new_label, codes = special[key]
    if new_label is not None:
        label = new_label
    for raw, shown in codes.items():
        if value == raw:
            value = shown
            break
    return label, value

def columnsToSql(columns):
    """
    Join column names into a SQL select list, e.g. ["a", "b"] -> "a,b".

    Names are inserted verbatim; pass trusted identifiers only.
    :param columns: iterable of column names
    :return: comma-separated string ("" for no columns)
    """
    return ",".join(columns)

def arr_split(arr, size):
    """
    Split a sequence / array into consecutive chunks of at most *size* items.

    The range now stops at len(arr): the old ``len(arr) + 1`` bound appended
    an empty trailing chunk whenever len(arr) was a multiple of *size*
    (including an empty input).
    :return: list of slices of *arr*
    """
    return [arr[i:i + size] for i in range(0, len(arr), size)]

def DataBaseInit():
    """(Re)build the stock_info table from data/stock_basic.csv."""
    saveStockData()

def sqlTest():
    """Debug helper: print the first 3 rows of user_information."""
    # DataBase holds the same path as the old "data\my_stock.db" literal,
    # without the invalid "\m" escape.
    con = sqlite3.connect(DataBase)
    try:
        df = pd.read_sql('select * from user_information LIMIT 3', con)
    finally:
        con.close()
    print(df)

if __name__ == '__main__':
    # Other maintenance entry points: DataBaseInit(), saveStockPath(),
    # getFilePath("000001")
    initTrainDate()

prepare.py

# -*-coding:utf-8 -*-
import sys
#print(sys.executable)
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import  stock_sql
from abc import ABCMeta, abstractmethod
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor,as_completed
import time
import get_info
import  datetime

#pd.show_versions()
# tushare ``daily`` field name -> Chinese description. The key order defines
# the feature column order used by columnSplit.
daily = dict(
    ts_code="股票代码",
    trade_date="交易日期",
    open="开盘价",
    high="最高价",
    low="最低价",
    close="收盘价",
    pre_close="昨收价",
    change="涨跌额",
    pct_chg="涨跌幅(未复权)",
    vol="成交量(手)",
    amount="成交额(千元)",
)

class stockInfo:
    """Lightweight identifier of a stock: its symbol and display name."""

    def __init__(self, symbol: str = "000001", name: str = "平安银行"):
        self.name = name
        self.symbol = symbol

# SEQUENCE_LEN: window length used by parseUntrainedData (the prepare*
#   classes use their own self.SEQUENCE_LEN instead).
# RREDICT_LEN: number of trailing samples returned for prediction/testing.
SEQUENCE_LEN, RREDICT_LEN = 20, 30

def parseUntrainedData(symbol, df, sequence_len=None):
    """
    Return the index of the first training window that has not been trained yet.

    :param symbol: stock symbol, used to look up its train_date in stock_date
    :param df: daily frame sorted by trade_date with a 0..n-1 index
               (as returned by stock_sql.getDailyFrame)
    :param sequence_len: window length; defaults to the module-level
                         SEQUENCE_LEN (20). NOTE(review): the prepare*
                         classes build windows of self.SEQUENCE_LEN (50) --
                         pass that value to keep the offsets consistent.
    :return: 0 when the stock was (almost) never trained, otherwise the
             window index following the one that ends on the last trained date
    """
    if sequence_len is None:
        sequence_len = SEQUENCE_LEN
    dateInfo = stock_sql.getLastDate(symbol)
    lastDate = int(dateInfo[1])
    # First row whose trade_date reached the last trained date (vectorised;
    # replaces a per-row iterrows scan). Labels equal positions here.
    reached = df.index[df["trade_date"] >= lastDate]
    index = int(reached[0]) if len(reached) else 0
    if index < sequence_len - 1:
        return 0  # never trained (or too few rows): start from the beginning
    return index + 1 - (sequence_len - 1)

def getTrianColnum(stockinfo:stockInfo,length=RREDICT_LEN):
    """
    Return the last *length* closing prices usable as model inputs.

    Mirrors prepareBase.dataProcess: the final row is excluded because it has
    no next-day target. Computed directly instead of constructing a
    prepareLogistics, whose __init__ sleeps for 10 seconds on every call.

    :param stockinfo: stock to load (only ``symbol`` is used)
    :param length: number of trailing rows to return
    :return: float ndarray of shape (<= length, 1)
    """
    df = stock_sql.getDailyFrame(stockinfo.symbol)
    close = df[["close"]].iloc[:-1].to_numpy()
    return np.array(close[-length:])

def columnSplit(verify=False):
    """
    Return the (x columns, y columns) used for training.

    x: every daily column except ts_code. y: ["close"], or ["close",
    "trade_date"] when *verify* is set, so y's dates can be checked.
    """
    column_x = [key for key in daily if key != 'ts_code']
    column_y = ["close", 'trade_date'] if verify else ["close"]
    return column_x, column_y

def getPredictTrainnData( stockInfo,length =RREDICT_LEN):
    """
    Return the last *length* rows of a stock's x columns, rounded to 2 decimals.
    :param stockInfo: mapping with a "symbol" key
    """
    frame = stock_sql.getDailyFrame(stockInfo["symbol"])
    features, _ = columnSplit()
    rounded = np.around(np.array(frame[features]), decimals=2)
    return rounded[-length:]

def dataSplit(npArr, split=0.9):
    """
    Split *npArr* along axis 0 into a leading part and the rest.
    :param npArr: array to split
    :param split: fraction of rows in the first part (rounded down)
    :return: (first part, remainder)
    """
    boundary = int(npArr.shape[0] * split)
    return npArr[:boundary], npArr[boundary:]

class prepareBase:
    """
    Base pipeline turning one stock's daily frame into supervised sequences.

    x: windows of SEQUENCE_LEN consecutive rows of every daily column except
       ts_code (see columnSplit).
    y: the y column(s) of the row following each window's last row.
    """

    def __init__(self):
        print(str(type(self)) + " __build __")
        self.verify = False        # True: keep trade_date in y and skip scaling
        self.only_untrain = False  # True: drop windows already trained on
        self.batch_size = 200      # stocks per batch in dataGenerator
        self.SEQUENCE_LEN = 50     # rows per input window
        self.RREDICT_LEN = 30      # trailing samples returned by getTestData
        self.tuple_x = 1           # number of separate model inputs in x

    def getInputShape(self):
        """Return the shape of one x sample: (SEQUENCE_LEN, number of x columns)."""
        column_x, _ = columnSplit()
        return (self.SEQUENCE_LEN, len(column_x))

    def dataProcess(self, df: pd.DataFrame) -> (np.array, np.array):
        """
        Split *df* into feature and target frames aligned for next-day prediction.
        :param df: daily frame with a 0..n-1 index (see stock_sql.getDailyFrame)
        :return: (x, y) where y's row k comes from df's row k + 1
        """
        column_x, column_y = columnSplit(self.verify)
        x, y = df[column_x], df[column_y]
        # The last x row has no next-day target; the first y row has no input.
        x, y = x.drop(len(x) - 1, axis=0), y.drop(0, axis=0)
        y.reset_index(drop=True, inplace=True)
        return x, y

    def dataSequence(self, x: pd.DataFrame, nor: bool = True) -> np.array:
        """
        Turn feature rows into overlapping windows of SEQUENCE_LEN rows.
        :param x: feature rows
        :param nor: min-max scale the rows (scaler fitted on this stock only)
        :return: (windows as float64 array, fitted-or-unused MinMaxScaler)
        """
        scaler = MinMaxScaler()
        if nor:
            data_all = scaler.fit_transform(np.array(x).astype("float64"))
        else:
            data_all = np.array(x)
        data = []
        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):
            data.append(data_all[i: i + self.SEQUENCE_LEN])
        return np.array(data).astype('float64'), scaler

    def dataSequence_y(self, y: pd.DataFrame, nor: bool = True) -> np.array:
        """
        Pick the target of each window built by dataSequence (row i + SEQUENCE_LEN - 1).
        :param y: target rows (already shifted by dataProcess)
        :param nor: min-max scale the targets
        :return: (targets, fitted-or-unused MinMaxScaler)
        """
        scaler = MinMaxScaler()
        if nor:
            data_all = scaler.fit_transform(np.array(y).astype("float64"))
        else:
            data_all = np.array(y)
        data = []
        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):
            data.append(data_all[i + self.SEQUENCE_LEN - 1])
        return np.array(data), scaler

    def test(self):
        """Print the last x row and the target of every untrained window (unscaled) for checking."""
        print("+++++++++++++++++++++++")
        self.verify = True
        self.only_untrain = True
        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        for info in np.array(df):
            data = self.getTrainData(stockInfo(symbol=info[0]))
            if data is None:
                continue  # too few (untrained) rows for this stock
            data_x, data_y = data[0], data[1]
            for i in range(len(data_x)):
                print(data_x[i][self.SEQUENCE_LEN - 1])
                print(data_y[i])

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """
        Build the training windows of a single stock.
        :param stockinfo: stock to load (only ``symbol`` is used)
        :return: (train_x, train_y, scaler_x, scaler_y), or None when the stock
                 has too few rows or (with only_untrain) no untrained rows
        """
        symbol = stockinfo.symbol
        df = stock_sql.getDailyFrame(symbol)
        if len(df) <= self.SEQUENCE_LEN:
            return None
        x, y = self.dataProcess(df)
        train_x, scaler_x = self.dataSequence(x, nor=not self.verify)
        train_y, scaler_y = self.dataSequence_y(y, nor=not self.verify)
        if self.only_untrain:
            # NOTE(review): parseUntrainedData offsets by the module-level
            # SEQUENCE_LEN (20) rather than self.SEQUENCE_LEN; confirm the
            # start index matches these windows.
            index = parseUntrainedData(symbol, df)
            if index >= len(train_x):
                return None  # start index beyond the data: nothing to train
            train_x, train_y = train_x[index:], train_y[index:]
        return train_x, train_y, scaler_x, scaler_y

    def getTestData(self, stockInfo: stockInfo):
        """
        Return the last RREDICT_LEN windows/targets of a stock plus its scalers.
        :return: (x, y, scaler_x, scaler_y), or None when getTrainData gave None
        """
        data = self.getTrainData(stockInfo)
        if data is None:
            return None
        train_x, train_y = data[0], data[1]
        return train_x[-self.RREDICT_LEN:], train_y[-self.RREDICT_LEN:], data[2], data[3]

    def dataGenerator(self, index: int = 0):
        """
        Yield training batches of batch_size stocks, in symbol order.
        :param index: skip the first *index* stocks
        :return: generator of (list of tuple_x input arrays, y array, stock rows)
        """
        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        stock_array = np.array(df)[index:]
        for s_array in stock_sql.arr_split(stock_array, self.batch_size):
            list_x, list_y = self.dataConcat(s_array)
            if len(list_x[0]) == 0 or len(list_y) == 0:
                continue  # no stock of this batch produced data
            data_x = [np.concatenate(list_x[i], axis=0) for i in range(self.tuple_x)]
            data_y = np.concatenate(list_y, axis=0)
            yield data_x, data_y, s_array

    def dataConcat(self, stock_list):
        """
        Collect the training data of every stock in *stock_list*.
        :param stock_list: rows whose first element is the symbol
        :return: (list_x, list_y); list_x[i] holds input i of every stock
        """
        list_x = [[] for _ in range(self.tuple_x)]
        list_y = []
        for info in stock_list:
            data = self.getTrainData(stockInfo(symbol=info[0]))
            print("stock info :{}\r\n".format(info))
            if data is None:
                continue
            # With a single input getTrainData returns x as one ndarray;
            # indexing it with [0] kept only the first window. Wrap it so
            # list_x[0] receives the whole (N, SEQUENCE_LEN, F) array.
            inputs = data[0]
            if self.tuple_x == 1 and isinstance(inputs, np.ndarray):
                inputs = (inputs,)
            for i in range(self.tuple_x):
                list_x[i].append(inputs[i])
            list_y.append(data[1])
        return list_x, list_y

    def dataConcatMultiple(self, stock_array, only_untrain=False):
        """
        Parallel dataConcat over 4 worker processes.
        :param stock_array: rows whose first element is the symbol
        :param only_untrain: unused; self.only_untrain applies
        :return: (list_x, list_y) as dataConcat
        """
        list_x = [[] for _ in range(self.tuple_x)]
        list_y = []
        split_array = np.array_split(stock_array, 4, axis=0)
        with ProcessPoolExecutor(max_workers=4) as executor:
            all_task = [executor.submit(self.dataConcat, list_a) for list_a in split_array]
            for future in as_completed(all_task):
                part_x, part_y = future.result()
                for i in range(self.tuple_x):
                    # was res[i][0], which read the y list for i >= 1
                    list_x[i].extend(part_x[i])
                list_y.extend(part_y)
        return list_x, list_y

class logisticsAllScaler(prepareBase):
    """
    prepareBase variant that scales every stock with ONE shared pair of
    MinMaxScalers (fitted on the concatenated data of all stocks) and caches
    the scaled per-stock arrays.
    """

    def __init__(self):
        super(logisticsAllScaler, self).__init__()
        self.SEQUENCE_LEN = 50
        self.list_x = None  # per-symbol scaled x; filled lazily by allScaler()
        time.sleep(10)  # NOTE(review): reason for this delay is not visible here

    def getDataWithIndex(self, stock_arr, only_untrain=False):
        """
        Load and split the daily data of every stock in *stock_arr*.
        :param stock_arr: rows of (symbol, name)
        :param only_untrain: also compute each stock's first untrained index
        :return: (list_x, list_y, row counts, {symbol: untrained start index})
        """
        list_x, list_y, index_list = [], [], []
        trian_map = {}
        index = 0
        for item in stock_arr:
            print(item[1])
            df = stock_sql.getDailyFrame(item[0])
            if only_untrain:
                index = parseUntrainedData(item[0], df)
            x, y = self.dataProcess(df)
            trian_map[item[0]] = index
            list_x.append(x)
            list_y.append(y)
            index_list.append(len(x))
        return list_x, list_y, index_list, trian_map

    def getDataWithIndexMultiple(self, stock_arr, only_untrain=False):
        """Parallel getDataWithIndex over 4 worker processes; same return value."""
        list_x, list_y, index_list = [], [], []
        trian_map = {}
        a_split = np.array_split(stock_arr, 4, axis=0)
        with ProcessPoolExecutor(max_workers=4) as executor:
            all_tasks = [executor.submit(self.getDataWithIndex, part, only_untrain)
                         for part in a_split]
            for future in as_completed(all_tasks):
                part_x, part_y, part_index, part_map = future.result()
                # Extending all four in the same step keeps symbol order and
                # row-count order aligned across chunks.
                list_x.extend(part_x)
                list_y.extend(part_y)
                index_list.extend(part_index)
                trian_map.update(part_map)
        return list_x, list_y, index_list, trian_map

    def allScaler(self):
        """
        Fit shared scalers over all stocks and cache the scaled arrays per symbol.
        Later calls return the cached result.
        :return: (list_x, list_y, scaler_x, scaler_y, train_map); list_x and
                 list_y map symbol -> scaled 2-D array
        """
        if self.list_x is not None:
            return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map

        df = stock_sql.getStockFrame()
        df.sort_values(by=["symbol"], inplace=True, ignore_index=True)
        arr = np.array(df)
        data_new_x, data_new_y, list_index, self.train_map = \
            self.getDataWithIndexMultiple(arr, self.only_untrain)
        data_new_x = np.concatenate(data_new_x, axis=0)
        data_new_y = np.concatenate(data_new_y, axis=0)
        self.scaler_x, self.scaler_y = MinMaxScaler(), MinMaxScaler()
        data_new_x = self.scaler_x.fit_transform(data_new_x)
        data_new_y = self.scaler_y.fit_transform(data_new_y)
        self.list_x, self.list_y = {}, {}
        # The i-th symbol owns the i-th block of rows (assumes unique symbols).
        list_symbols = list(self.train_map.keys())
        start = 0
        for i in range(len(list_index)):
            stop = start + list_index[i]
            self.list_x[list_symbols[i]] = data_new_x[start:stop]
            self.list_y[list_symbols[i]] = data_new_y[start:stop]
            start = stop

        return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """
        Build one stock's windows from the globally scaled cache.
        :return: (train_x, train_y, scaler_x, scaler_y), or None when nothing
                 is left to train (only_untrain)
        """
        self.allScaler()
        symbol = stockinfo.symbol
        x, y = self.list_x[symbol], self.list_y[symbol]
        train_x, _ = self.dataSequence(x, nor=False)
        train_y, _ = self.dataSequence_y(y, nor=False)
        if self.only_untrain:
            # was self.trian_map, an attribute that is never set
            index = self.train_map[symbol]
            if index >= len(train_x):
                return None  # start index beyond the data: nothing to train
            train_x, train_y = train_x[index:], train_y[index:]
        return train_x, train_y, self.scaler_x, self.scaler_y

class prepareLogistics(prepareBase):
    """Data preparer for the regression model; sets ``SEQUENCE_LEN`` to 50."""

    def __init__(self):
        super().__init__()
        self.SEQUENCE_LEN = 50
        time.sleep(10)

class prepareClassify(prepareBase):
    """Data preparer for the up/down classifier.

    Labels are 1 when column 0 of the y columns rises from one row to the
    next, and 0 otherwise.
    """

    def __init__(self):
        super().__init__()
        time.sleep(10)

    def dataProcess(self, df: pd.DataFrame) -> (np.array, np.array):
        """Split ``df`` into feature rows and next-row rise labels.

        :param df: DataFrame containing the stock information.
        :return: (x, labels). Row i of x is paired with whether y column 0
                 rises from row i to row i+1, so the last x row (which has no
                 successor) is dropped. When ``self.verify`` is set, a
                 human-readable description column is prepended to the labels.
        """
        # Split the frame into feature and target columns.
        column_x, column_y = columnSplit(self.verify)
        x = np.array(df[column_x])
        y_t = np.array(df[column_y])
        # Pair each row with its successor: the last x row has no target and
        # the first y row has no input.
        neighbours = list(zip(y_t[:-1], y_t[1:]))
        deltas = [after[0] - before[0] for before, after in neighbours]
        labels = np.int32(np.array(deltas) > 0).reshape(-1, 1)
        if self.verify:
            # Column 1 of y is presumably the trade date — TODO confirm.
            notes = [str(before[1]) + "->" + str(after[1]) + " :" + str(after[0]) + "-" + str(before[0])
                     for before, after in neighbours]
            labels = np.concatenate((np.array(notes).reshape(-1, 1), labels), axis=1)
        return x[:len(x) - 1], labels

    def dataSequence_y(self, y: pd.DataFrame, nor: bool = False) -> np.array:
        """Return the y row that ends each ``SEQUENCE_LEN``-long window, plus None."""
        values = np.array(y)
        window = self.SEQUENCE_LEN
        picked = [values[end] for end in range(window - 1, len(values))]
        return np.array(picked), None

class prepareLogistics_Ex(prepareBase):
    """Regression data preparer that adds area and industry vectors as extra inputs."""

    def __init__(self):
        super().__init__()
        self.SEQUENCE_LEN = 50
        self.tuple_x = 3
        basic = get_info.stock_basic()
        # Per-symbol area/industry vectors plus the shapes of those vector tables.
        (self.area_vec, self.indu_vec,
         self.area_shape, self.indu_shape) = basic.getExVec()
        time.sleep(10)

    def getInputShape(self):
        """Return (daily_shape, (area_width,), (industry_width,)) for the three model inputs."""
        return (super().getInputShape(),
                (self.area_shape[1],),
                (self.indu_shape[1],))

    def getTrainData(self, stockinfo: stockInfo) -> (np.array, np.array):
        """Return ((daily_x, area_x, indu_x), train_y, scaler_x, scaler_y), or None.

        The stock's area and industry vectors are repeated once per sample,
        so they line up row-by-row with the daily windows.
        """
        base = super().getTrainData(stockinfo)
        if base is None:
            return None
        daily_x, train_y, self.scaler_x, self.scaler_y = base
        rows = daily_x.shape[0]
        area_row, indu_row = self.area_vec[stockinfo.symbol], self.indu_vec[stockinfo.symbol]
        area_x = np.tile(area_row, (rows, 1))
        indu_x = np.tile(indu_row, (rows, 1))
        return (daily_x, area_x, indu_x), train_y, self.scaler_x, self.scaler_y

if __name__ == "__main__":
    # Manual smoke run of the classification data preparer.
    classify_prep = prepareClassify()
    classify_prep.test()

trainning.py

import sys
# print(sys.executable)
import numpy as np
import matplotlib.pyplot as plt
import gc

import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
import datetime
import tensorflow.keras as keras
import prepare

import stock_sql
from abc import ABCMeta, abstractmethod
import time

import pickle

# gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# cpus = tf.config.experimental.list_physical_devices(device_type='CPU')
# print(gpus)
# print(cpus)
#
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# tf.config.experimental.set_visible_devices(devices=gpus[0], device_type='GPU')
# tf.config.experimental.set_visible_devices(devices=cpus[0], device_type='CPU')
# tf.config.experimental.set_virtual_device_configuration(
#     gpus[1],
#     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1536)]
# )

# Weight checkpoint paths and the file holding the saved training index.
weight_file = "model/weights/my_weight"
weight_file_1 = "model/weights/my_weight_1"
index_file = "index.txt"
# Per-run log directory (Windows-style separators), suffixed with the start time.
_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "log\\model_train\\" + _run_stamp
model = None

def getTrainIndex():
    """Read the training index saved in ``index_file`` and return it as an int.

    Only the first line is parsed; ``int`` tolerates its trailing newline.
    """
    with open(index_file, "r") as handle:
        first_line = handle.readline()
    return int(first_line)

def saveTrainIndex(index):
    """Overwrite ``index_file`` with ``str(index)``."""
    with open(index_file, "w") as handle:
        handle.write(str(index))
        # Closing the file at the end of the ``with`` block flushes the write.

def build_model(shape):
    """Build the generic sequence regressor and save its diagram.

    The diagram is written to ``picture/multi_model.png``.

    NOTE(review): this previously called ``model_1``, which is not defined in
    this module, so every call raised NameError. ``model_2`` is the only
    ``model_N(shape)`` builder here and is used instead. Confirm that this is
    the intended architecture.

    :param shape: input shape passed to ``keras.Input``.
    :return: the compiled Keras model.
    """
    model = model_2(shape)
    keras.utils.plot_model(model, 'picture/multi_model.png', show_shapes=True)
    return model

def model_2(shape):
    """Two stacked 50-unit tanh LSTMs followed by a tanh(10) -> linear(1) head.

    Both LSTMs return full sequences, so the dense head is applied at every
    time step. The model is compiled with Adam, MSE loss and an MAE metric.
    """
    layers = keras.layers
    net = keras.Sequential([
        keras.Input(shape=shape),
        layers.LSTM(units=50, activation='tanh', return_sequences=True),
        layers.LSTM(units=50, activation='tanh', return_sequences=True),
        layers.Dense(units=10, activation="tanh"),
        layers.Dense(units=1),
    ])
    net.compile(optimizer=keras.optimizers.Adam(), loss="mse", metrics=[keras.metrics.mae])
    net.summary()
    return net

def testing(stockInfo):
    """Predict one stock with a freshly built ``logisticsTrain``.

    :param stockInfo: stock descriptor (``{symbol: XXX, name}``) passed to ``predict``.
    :return: the ``(real_y, predict_y)`` pair from ``logisticsTrain.predict``.
    """
    trainer = logisticsTrain()
    return trainer.predict(stockInfo)

def showData(test_y, predict_y):
    """Plot predictions (blue stars) and real values (green triangles), then show the figure."""
    positions = list(range(len(predict_y)))
    for series, style in ((predict_y, "b*--"), (test_y, "gv--")):
        plt.plot(positions, series, style)
    plt.show()

def template():
    """Reference Keras multi-input / multi-output example.

    The network reads a document body, title and tag vector, and predicts a
    priority (sigmoid) and a department (softmax). The model is only plotted
    to ``picture/template_model.png``; it is not returned.
    """
    num_words, num_tags, num_departments = 2000, 12, 4

    # Inputs
    body_in = keras.Input(shape=(None,), name='body')
    title_in = keras.Input(shape=(None,), name='title')
    tag_in = keras.Input(shape=(num_tags,), name='tag')

    # Word embeddings
    body_emb = keras.layers.Embedding(num_words, 64)(body_in)
    title_emb = keras.layers.Embedding(num_words, 64)(title_in)

    # Sequence encoders, merged with the tag vector
    body_vec = keras.layers.LSTM(32)(body_emb)
    title_vec = keras.layers.LSTM(128)(title_emb)
    merged = keras.layers.concatenate([title_vec, body_vec, tag_in])

    # Output heads
    priority = keras.layers.Dense(1, activation='sigmoid', name='priority')(merged)
    department = keras.layers.Dense(
        num_departments, activation='softmax', name='department')(merged)

    demo = keras.Model(inputs=[body_in, title_in, tag_in],
                       outputs=[priority, department])
    keras.utils.plot_model(demo, 'picture/template_model.png', show_shapes=True)

def logisticsMode(shape):
    """Stacked-LSTM regressor with one scalar output.

    It uses four tanh LSTMs (500, 500, 200, 200 units); only the last one
    drops the time axis. A tanh dense head (200 -> 20) feeds a linear 1-unit
    output. The model is compiled with Adam, MSE loss and an MAE metric.
    """
    lstm_widths = (500, 500, 200, 200)
    last = len(lstm_widths) - 1
    stack = [keras.Input(shape=shape)]
    stack += [
        keras.layers.LSTM(units=width, activation='tanh', return_sequences=(pos != last))
        for pos, width in enumerate(lstm_widths)
    ]
    stack += [
        keras.layers.Dense(units=200, activation="tanh"),
        keras.layers.Dense(units=20, activation="tanh"),
        keras.layers.Dense(units=1),
    ]
    net = keras.Sequential(stack)
    net.compile(optimizer=keras.optimizers.Adam(), loss="mse", metrics=[keras.metrics.mae])
    net.summary()
    return net

def classifyModel(shape):
    """LSTM classifier producing logits for 3 classes.

    Four sequence-returning tanh LSTMs (500, 500, 200, 50) are flattened
    into a relu Dense(200) and then a linear Dense(3). The model is compiled
    with Adam and sparse categorical cross-entropy on logits.
    """
    inputs = keras.Input(shape=shape)
    hidden = inputs
    for width in (500, 500, 200, 50):
        hidden = keras.layers.LSTM(units=width, activation='tanh', return_sequences=True)(hidden)
    hidden = keras.layers.Flatten()(hidden)
    hidden = keras.layers.Dense(units=200, activation="relu")(hidden)
    logits = keras.layers.Dense(units=3)(hidden)

    net = keras.Model(inputs=inputs, outputs=logits)
    net.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
    net.summary()
    return net

def classifyModel_1(shape):
    """Deeper LSTM classifier producing logits for 3 classes.

    Five sequence-returning tanh LSTMs (500, 500, 200, 200, 100) are
    flattened into relu Dense(200) -> relu Dense(100) -> linear Dense(3).
    The model is compiled with Adam and sparse categorical cross-entropy on
    logits.
    """
    inputs = keras.Input(shape=shape)
    hidden = inputs
    for width in (500, 500, 200, 200, 100):
        hidden = keras.layers.LSTM(units=width, activation='tanh', return_sequences=True)(hidden)
    hidden = keras.layers.Flatten()(hidden)
    for width in (200, 100):
        hidden = keras.layers.Dense(units=width, activation="relu")(hidden)
    logits = keras.layers.Dense(units=3)(hidden)

    net = keras.Model(inputs=inputs, outputs=logits)
    net.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
    net.summary()
    return net

def logisticsModel_Ex(daily_shape, area_shape, indu_shape):
    """Three-input regressor: daily sequences plus area and industry vectors.

    The daily input passes through four relu LSTMs (500, 200, 200, 200); the
    last one drops the time axis. Its output is concatenated with the raw
    area and industry vectors and fed through relu Dense(200, 100, 100) into
    a linear 1-unit output. The model is compiled with Adam, MSE loss and an
    MAE metric.
    """
    daily_input = keras.Input(shape=daily_shape)
    area_input = keras.Input(shape=area_shape)
    indu_input = keras.Input(shape=indu_shape)

    lstm_widths = (500, 200, 200, 200)
    encoded = daily_input
    for pos, width in enumerate(lstm_widths):
        keep_sequence = pos < len(lstm_widths) - 1
        encoded = keras.layers.LSTM(units=width, activation="relu",
                                    return_sequences=keep_sequence)(encoded)

    hidden = keras.layers.concatenate([encoded, area_input, indu_input])
    for width in (200, 100, 100):
        hidden = keras.layers.Dense(units=width, activation="relu")(hidden)
    output = keras.layers.Dense(units=1)(hidden)

    net = keras.Model(inputs=[daily_input, area_input, indu_input], outputs=output)
    net.compile(optimizer=keras.optimizers.Adam(), loss="mse", metrics=[keras.metrics.mae])
    net.summary()
    return net

class trainBase:
    """Shared training and prediction driver for the concrete trainer classes.

    Subclasses set ``file_path``, ``batch_size`` and ``epochs``. In
    :meth:`buildModel` they populate ``self.dataPre`` (a ``prepare`` data
    preparer) and ``self.model`` (a compiled Keras model).
    """

    def __init__(self):
        self.dataPre: prepare.prepareBase = None
        self.model: keras.Model = None

    # NOTE(review): trainBase does not use ABCMeta, so this decorator is not
    # enforced when the class is instantiated.
    @abstractmethod
    def buildModel(self):
        """Create ``self.dataPre`` and ``self.model``; implemented by subclasses."""
        pass

    def trainAll(self):
        """Fit the model on every batch yielded by ``dataPre.dataGenerator(0)``.

        After each batch the weights are saved to ``self.file_path``, and the
        batch's stock list is passed to ``stock_sql.updateTrianList``.
        """
        for train_x, train_y, stock_list in self.dataPre.dataGenerator(0):
            # Multi-input models receive a tuple/list of arrays; single-input
            # models receive one array. Print the shape of every model input
            # rather than assuming exactly three inputs.
            inputs = train_x if isinstance(train_x, (tuple, list)) else (train_x,)
            for part in inputs:
                print(np.shape(part))
            print(np.shape(train_y))
            self.model.fit(train_x, train_y, batch_size=self.batch_size,
                           epochs=self.epochs, validation_split=0.05)
            self.model.save_weights(self.file_path)
            stock_sql.updateTrianList(stock_list)
            # Drop every reference to the batch before collecting garbage.
            del train_x, train_y, stock_list, inputs
            gc.collect()
            gc.collect()

    def predict(self, stockinfo: prepare.stockInfo) -> (np.array, np.array):
        """Return ``(real_y, predict_y)`` for one stock.

        If ``stock_sql.getRecordData`` finds a saved record, it is returned
        as-is. Otherwise the data from ``dataPre.getTestData`` is predicted.
        Both series are inverse-scaled when a y-scaler is provided, rounded
        to two decimals, saved with ``stock_sql.SaveRecordData`` and
        returned. Two empty arrays are returned when there is no test data.
        """
        data = stock_sql.getRecordData(stockinfo.symbol)
        if data:
            return data[0], data[1]
        data = self.dataPre.getTestData(stockinfo)
        if data is None:
            return np.array([]), np.array([])
        test_x, real_y, _, scaler_y = data[0], data[1], data[2], data[3]
        predict_y = self.model.predict(test_x)
        if scaler_y is not None:
            real_y = scaler_y.inverse_transform(real_y)
            predict_y = scaler_y.inverse_transform(predict_y.astype("float64"))
        real_y, predict_y = np.around(real_y, decimals=2), np.around(predict_y, decimals=2)

        stock_sql.SaveRecordData(stockinfo, real_y, predict_y)
        return real_y, predict_y

class logisticsTrain(trainBase):
    """Trainer for the stacked-LSTM regressor (``logisticsMode``)."""

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/my_weight"
        self.batch_size = 1024
        self.epochs = 20
        self.buildModel()

    def buildModel(self):
        """Create the data preparer and the model, then try to restore saved weights."""
        print("parepareLogistics")
        self.dataPre = prepare.logisticsAllScaler()

        self.model = logisticsMode(self.dataPre.getInputShape())
        try:
            print("load_weights from " + self.file_path)
            self.model.load_weights(self.file_path)
            print(" success ")
        except Exception as err:
            # Failing to load is not fatal: the freshly built model is kept.
            # Catching Exception (not a bare except) lets Ctrl-C/SystemExit
            # propagate, and printing the error shows why loading failed.
            print("load_weight failed", err)

        time.sleep(10)

class classifyTrain(trainBase):
    """Trainer for the 3-logit LSTM classifier (``classifyModel``) on ``prepareClassify`` data."""

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/classify_weight"
        self.batch_size = 1024
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data preparer and the model, then try to restore saved weights."""
        self.dataPre = prepare.prepareClassify()
        print("prepareClassify")
        self.model = classifyModel(self.dataPre.getInputShape())
        try:
            print("load_weights from " + self.file_path)
            self.model.load_weights(self.file_path)
            print(" success ")
        except Exception as err:
            # Failing to load is not fatal: the freshly built model is kept.
            # Catching Exception (not a bare except) lets Ctrl-C/SystemExit
            # propagate, and printing the error shows why loading failed.
            print("load_weight failed", err)
        time.sleep(10)

class classifyTrain_1(trainBase):
    """Trainer for the deeper LSTM classifier (``classifyModel_1``) on ``prepareClassify`` data."""

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/classify_weight_1"
        self.batch_size = 1024
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data preparer and the model, then try to restore saved weights."""
        self.dataPre = prepare.prepareClassify()
        print("prepareClassify")
        self.model = classifyModel_1(self.dataPre.getInputShape())
        try:
            print("load_weights from " + self.file_path)
            self.model.load_weights(self.file_path)
            print(" success ")
        except Exception as err:
            # Failing to load is not fatal: the freshly built model is kept.
            # Catching Exception (not a bare except) lets Ctrl-C/SystemExit
            # propagate, and printing the error shows why loading failed.
            print("load_weight failed", err)
        time.sleep(10)

class logisticsTrain_Ex(trainBase):
    """Trainer for the three-input regressor (``logisticsModel_Ex``) on ``prepareLogistics_Ex`` data."""

    def __init__(self):
        super().__init__()
        self.file_path = "model/weights/logisticsEx_weight"
        self.batch_size = 512
        self.epochs = 5
        self.buildModel()

    def buildModel(self):
        """Create the data preparer and the model, then try to restore saved weights."""
        self.dataPre = prepare.prepareLogistics_Ex()
        print("prepareClassifyEx")
        daily_shape, area_shape, indu_shape = self.dataPre.getInputShape()
        self.model = logisticsModel_Ex(daily_shape, area_shape, indu_shape)
        try:
            print("load_weights from " + self.file_path)
            self.model.load_weights(self.file_path)
            print(" success ")
        except Exception as err:
            # Failing to load is not fatal: the freshly built model is kept.
            # Catching Exception (not a bare except) lets Ctrl-C/SystemExit
            # propagate, and printing the error shows why loading failed.
            print("load_weight failed", err)
        time.sleep(10)

if __name__ == '__main__':
    # stock_sql.initTrainDate()

    trainer = logisticsTrain_Ex()
    # Turn off only_untrain so training is not restricted to not-yet-trained data.
    trainer.dataPre.only_untrain = False
    trainer.trainAll()

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注