{"id":110,"date":"2020-02-04T12:33:55","date_gmt":"2020-02-04T04:33:55","guid":{"rendered":"https:\/\/www.nickchan.cn\/?p=110"},"modified":"2024-07-11T02:18:02","modified_gmt":"2024-07-10T18:18:02","slug":"%e8%82%a1%e7%a5%a8%e6%a8%a1%e5%9e%8b%e9%a2%84%e6%b5%8b","status":"publish","type":"post","link":"https:\/\/www.nickchan.cn\/index.php\/2020\/02\/04\/%e8%82%a1%e7%a5%a8%e6%a8%a1%e5%9e%8b%e9%a2%84%e6%b5%8b\/","title":{"rendered":"\u80a1\u7968\u6a21\u578b\u9884\u6d4b"},"content":{"rendered":"<h1>\u7b80\u4ecb\uff1a<\/h1>\n<p>\u5904\u7406\u8fc7\u7a0b\u5927\u4f53\u5982\u4e0b\uff1a<br \/>\n1\u3001\u4f7f\u7528tushare\u83b7\u53d6stock\u4fe1\u606f<br \/>\n2\u3001\u5bf9\u6570\u636e\u8fdb\u884c\u5904\u7406\uff0c\u505a\u597dtrain_x\u548ctrain_y\u7684\u5bf9\u5e94\u5173\u7cfb<br \/>\n3\u3001\u8bad\u7ec3\u548c\u9884\u6d4b\uff0c\u7f51\u7edc\u662f\u5b66\u4e60x\u4e0ey\u4e4b\u95f4\u7684\u6620\u5c04\u5173\u7cfb\uff0c\u7f51\u7edc\u8981\u4e0e\u6570\u636e\u5339\u914d  <\/p>\n<p>1\u3001get_info.py \u7528\u4e8e\u4ecetushare\u83b7\u53d6stock\u4fe1\u606f\uff0c\u76ee\u524d\u53ea\u6709\u65e5\u7ebf\u6570\u636e\uff0c\u540e\u9762\u4f1a\u589e\u52a0\uff08tushare\u79ef\u5206\u4e0d\u591f\uff0c\u6709\u4e9b\u6570\u636e\u83b7\u53d6\u4e0d\u5230\uff0c\u5927\u5bb6\u6ce8\u518c\u4e0b\u7ed9\u70b9\u79ef\u5206\uff09<br \/>\n2\u3001stock_sql.py \u7528\u4e8e\u5c06\u90e8\u5206\u4fe1\u606f\u8bb0\u5f55\u5230\u6570\u636e\u5e93\uff0c\u65b9\u4fbf\u67e5\u8be2\u68c0\u7d22\u3002\u6570\u636e\u5e93\u4f7f\u7528\u7684\u662fsqlite3<br \/>\n3\u3001prepare.py \u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u5904\u7406\uff0c\u751f\u6210train_x\uff0ctrain_y\u7684\u5bf9\u5e94\u5173\u7cfb\uff0c\u6ee1\u8db3\u7f51\u7edc\u8bad\u7ec3\u9700\u8981\u3002NN\u53ef\u4ee5\u5b66\u4e60\u6620\u5c04\u89c4\u5f8b\uff0c\u8fdb\u884c\u9884\u6d4b\u3002<br \/>\n4\u3001trainning.py \u6784\u5efa\u7f51\u7edc\u6a21\u578b\uff0c\u5e76\u8fdb\u884c\u8bad\u7ec3\u3002<br \/>\n5\u3001evaluate.py \u5bf9\u6a21\u578b\u8fdb\u884c\u7b80\u5355\u7684\u8bc4\u4f30\u3002<br \/>\n6\u3001server.py 
flask\u505a\u7684\u540e\u53f0\u7528\u4e8e\u6570\u636e\u5c55\u793a\u3002<\/p>\n<h1>\u4ee3\u7801\u5b9e\u73b0\uff1a<\/h1>\n<h1>get_info.py<\/h1>\n<pre><code class=\"language-python\">from concurrent.futures import ProcessPoolExecutor, as_completed\n\nfrom pandas import DataFrame\n\nimport stock_sql\nimport pandas as pd\nimport numpy as np\nimport datetime\nimport time\nimport tushare as ts\n\nimport jieba\nimport re\nfrom gensim.test.utils import common_texts, get_tmpfile\nfrom gensim.models import Word2Vec,word2vec\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nfrom sklearn.preprocessing import OneHotEncoder\n\n#path = get_tmpfile(&quot;word2vec.model&quot;)\n\nprint(ts.__version__)\n\nts.set_token(&quot;d1af48f518c17415b1b98b2ce84ab7b1a0025adfdde78e22513b31ec&quot;)\npro = ts.pro_api()\n\nd = {\n    &quot;ts_code&quot;: &quot;TS\u4ee3\u7801&quot;,\n     &quot;symbol&quot;: &quot;\u80a1\u7968\u4ee3\u7801&quot;,\n     &quot;name&quot;: &quot;\u80a1\u7968\u540d\u79f0&quot;,\n     &quot;area&quot;: &quot;\u6240\u5728\u5730\u57df&quot;,\n     &quot;industry&quot;: &quot;\u6240\u5c5e\u884c\u4e1a&quot;,\n     &quot;fullname&quot;: &quot;\u80a1\u7968\u5168\u79f0&quot;,\n     &quot;enname&quot;: &quot;\u82f1\u6587\u5168\u79f0&quot;,\n     &quot;market&quot;: &quot;\u5e02\u573a\u7c7b\u578b(\u4e3b\u677f\/\u4e2d\u5c0f\u677f\/\u521b\u4e1a\u677f\/\u79d1\u521b\u677f)&quot;,\n     &quot;exchange&quot;: &quot;\u4ea4\u6613\u6240\u4ee3\u7801&quot;,\n     &quot;curr_type&quot;: &quot;\u4ea4\u6613\u8d27\u5e01&quot;,\n     &quot;list_status&quot;: &quot;\u4e0a\u5e02\u72b6\u6001:L\u4e0a\u5e02 D\u9000\u5e02 P\u6682\u505c\u4e0a\u5e02&quot;,\n     &quot;list_date&quot;: &quot;\u4e0a\u5e02\u65e5\u671f&quot;,\n     &quot;delist_date&quot;: &quot;\u9000\u5e02\u65e5\u671f&quot;,\n     &quot;is_hs&quot;: &quot;\u662f\u5426\u6caa\u6df1\u6e2f\u901a\u6807\u7684\uff0cN\u5426 H\u6caa\u80a1\u901a S\u6df1\u80a1\u901a&quot;,\n     }\n\ndaily = {\n    &quot;ts_code&quot;: 
&quot;\u80a1\u7968\u4ee3\u7801&quot;,\n    &quot;trade_date&quot;: &quot;\u4ea4\u6613\u65e5\u671f&quot;,\n    &quot;open&quot;: &quot;\u5f00\u76d8\u4ef7&quot;,\n    &quot;high&quot;: &quot;\u6700\u9ad8\u4ef7&quot;,\n    &quot;low&quot;: &quot;\u6700\u4f4e\u4ef7&quot;,\n    &quot;close&quot;: &quot;\u6536\u76d8\u4ef7&quot;,\n    &quot;pre_close&quot;: &quot;\u6628\u6536\u4ef7&quot;,\n    &quot;change&quot;: &quot;\u6da8\u8dcc\u989d&quot;,\n    &quot;pct_chg&quot;: &quot;\u6da8\u8dcc\u5e45(\u672a\u590d\u6743)&quot;,\n    &quot;vol&quot;: &quot;\u6210\u4ea4\u91cf(\u624b)&quot;,\n    &quot;amount&quot;: &quot;\u6210\u4ea4\u989d(\u5343\u5143)&quot;,\n}\n\ndef dateRange(beginDate, endDate):\n    dates = []\n    dt = datetime.datetime.strptime(beginDate, &quot;%Y%m%d&quot;)\n    date = beginDate[:]\n    while date &lt;= endDate:\n        dates.append(date)\n        dt = dt + datetime.timedelta(1)\n        date = dt.strftime(&quot;%Y%m%d&quot;)\n    return dates\n\nclass stock_basic:\n    &quot;&quot;&quot;\n    \u80a1\u7968\u57fa\u672c\u4fe1\u606f\n    &quot;&quot;&quot;\n    basic_info = {\n        &quot;ts_code&quot;: &quot;TS\u4ee3\u7801&quot;,\n        &quot;symbol&quot;: &quot;\u80a1\u7968\u4ee3\u7801&quot;,\n        &quot;name&quot;: &quot;\u80a1\u7968\u540d\u79f0&quot;,\n        &quot;area&quot;: &quot;\u6240\u5728\u5730\u57df&quot;,\n        &quot;industry&quot;: &quot;\u6240\u5c5e\u884c\u4e1a&quot;,\n        &quot;fullname&quot;: &quot;\u80a1\u7968\u5168\u79f0&quot;,\n        &quot;enname&quot;: &quot;\u82f1\u6587\u5168\u79f0&quot;,\n        &quot;market&quot;: &quot;\u5e02\u573a\u7c7b\u578b(\u4e3b\u677f\/\u4e2d\u5c0f\u677f\/\u521b\u4e1a\u677f\/\u79d1\u521b\u677f)&quot;,\n        &quot;exchange&quot;: &quot;\u4ea4\u6613\u6240\u4ee3\u7801&quot;,\n        &quot;curr_type&quot;: &quot;\u4ea4\u6613\u8d27\u5e01&quot;,\n        &quot;list_status&quot;: &quot;\u4e0a\u5e02\u72b6\u6001:L\u4e0a\u5e02 D\u9000\u5e02 P\u6682\u505c\u4e0a\u5e02&quot;,\n        
&quot;list_date&quot;: &quot;\u4e0a\u5e02\u65e5\u671f&quot;,\n        &quot;delist_date&quot;: &quot;\u9000\u5e02\u65e5\u671f&quot;,\n        &quot;is_hs&quot;: &quot;\u662f\u5426\u6caa\u6df1\u6e2f\u901a\u6807\u7684\uff0cN\u5426 H\u6caa\u80a1\u901a S\u6df1\u80a1\u901a&quot;,\n    }\n\n    def __init__(self):\n        self.file_path =&quot;data\/stock_basic.csv&quot;\n        self.record_path= &quot;data\/data.csv&quot;\n\n    def getBasicInfo(self):\n        # \u83b7\u53d6\u4fe1\u606f\n        fields = list(self.basic_info.keys())\n        headers = list(self.basic_info.values())\n        data: DataFrame = pro.query(&#039;stock_basic&#039;, exchange=&#039;&#039;,\n                                    list_status=&#039;L&#039;, fields=fields)\n        # print(data.values)\n        data.columns = headers\n        data.to_csv(self.file_path, header=True, encoding=&#039;utf-8&#039;, index=False)\n\n    def recordDateInfo(self):\n        # \u521d\u59cb\u5316\u8ddf\u65b0\u65f6\u95f4\n        data = pd.read_csv(&quot;data\/stock_basic.csv&quot;, dtype=object)\n        # a \u80a1\u7968\u7684\u5168\u90e8\u4fe1\u606f \uff0cb \u8981\u53bb\u6389\u7684\u51e0\u5217\n        a = list(d.values())\n        b = [&quot;TS\u4ee3\u7801&quot;, &quot;\u80a1\u7968\u540d\u79f0&quot;, &quot;\u4e0a\u5e02\u65e5\u671f&quot;, &quot;\u9000\u5e02\u65e5\u671f&quot;]\n        today = datetime.date.today().strftime(&quot;%Y%m%d&quot;)\n\n        header = [val for val in a if val not in b]\n        print(header)\n        data_1 = data.drop(header, axis=1)\n        data_1.insert(loc=len(data_1.columns), column=&quot;\u66f4\u65b0\u65f6\u95f4&quot;,\n                      value=today, allow_duplicates=False)\n\n        for row in data_1.values:\n            row[4] = row[2]\n\n        data_1.to_csv(&quot;data\/data.csv&quot;, header=True, encoding=&#039;utf-8&#039;, index=False)\n\n    def getExVec(self):\n        df = stock_sql.getStockFrame(headers=[&quot;symbol&quot;, &quot;name&quot;, 
&quot;area&quot;, &quot;industry&quot;])\n        df.sort_values(by=[&quot;symbol&quot;], inplace=True, ignore_index=True)\n        df.fillna(&quot; &quot;,inplace=True)\n        area_vec = OneHotEncoder(sparse=False).fit_transform(df[[&quot;area&quot;]])\n        indu_vec = OneHotEncoder(sparse=False).fit_transform(df[[&quot;industry&quot;]])\n        # data = pd.read_csv(&quot;data\/stock_basic.csv&quot;, dtype=object)\n        # aa = OneHotEncoder().fit_transform(data[[&quot;\u6240\u5728\u5730\u57df&quot;]])\n        data_area= {}\n        data_indu={}\n        symbols = np.array(df[[&quot;symbol&quot;]])\n        # areas =np.array(df[[&quot;area&quot;]]).reshape(-1)\n        # data_t={}\n        for i in range(len(symbols)):\n            data_area[symbols[i][0]] = area_vec[i]\n            data_indu[symbols[i][0]] = indu_vec[i]\n        return  data_area, data_indu, area_vec.shape,indu_vec.shape\n\nclass stock_daily(object):\n\n    def __init__(self):\n        pass\n\n    def updateDailyInfo(self):\n        &quot;&quot;&quot;\n        \u66f4\u65b0\u6240\u6709 stock \u7684\u65e5\u4fe1\u606f\n        :return:\n        &quot;&quot;&quot;\n        df = stock_sql.getTradeDateList()\n        # \u6309\u7167symbol \u5347\u7ea7\n        # df.sort_values(by=[&quot;symbol&quot;], inplace=True, ignore_index=True)\n        stock_list = np.array(df)\n        self. 
updateDailyList(stock_list)\n        # print(stock_list)\n\n    def createDateTable(self):\n        &quot;&quot;&quot;\n        \u521b\u5efa\u6570\u636e\u5e93\u7528\u6765\u8bb0\u5f55 \u66f4\u65b0\u7684\u65e5\u671f\u4fe1\u606f ,\u5e76\u8fdb\u884c\u6839\u636e stockBasic\u4fe1\u606f\u8fdb\u884c\u521d\u59cb\u5316\n        :return:\n        &quot;&quot;&quot;\n        stock_sql.createDateTable()\n        stock_sql.initDateTable()\n\n    def getDailyInfo(self, symbol, ts_code, start_time:str, end_time:str) -&gt; pd.DataFrame:\n        &quot;&quot;&quot;\n        \u83b7\u53d6stock \u65e5\u4fe1\u606f\n        :param symbol: stock \u4ee3\u7801\n        :param ts_code:\n        :param start_time:  \u8d77\u59cb\u65e5\u671f\n        :param end_time:    \u7ed3\u675f\u65e5\u671f\n        :return:   daily \u4fe1\u606f\n        &quot;&quot;&quot;\n        df_list = []\n        while start_time != end_time:\n            df_t = pro.daily(\n                ts_code=ts_code, start_date=start_time, end_date=end_time)\n            df_list.append(df_t)\n            end_time = df_t.tail(1)[&quot;trade_date&quot;].values[0]\n\n        if len(df_list) != 0:\n            data = pd.concat(df_list, join=&quot;inner&quot;)\n        else:\n            data = None\n\n        return data\n\n    def dailyInfoConcat(self, symbol: int, data:pd.DataFrame):\n        &quot;&quot;&quot;\n        \u5c06\u65b0\u589e\u7684\u65e5\u65b0\u589e\u4fe1\u606f\u4e0e\u539f\u8bb0\u5f55\u7684\u4fe1\u606f\u5408\u5e76\u4fdd\u5b58\n        :param symbol: stock \u4ee3\u7801\n        :param data: \u65b0\u589e\u6570\u636e\n        :return: \u6700\u65b0\u66f4\u65b0\u65e5\u671f\n        &quot;&quot;&quot;\n        path = stock_sql.getFilePath(symbol)\n        df = stock_sql.getStockData(path)\n        data.trade_date = data.trade_date.astype(&quot;int64&quot;)\n        # \u65b0\u6570\u636e\u653e\u5728\u524d\n        data_n = pd.concat([data, df], join=&#039;inner&#039;, ignore_index=True)\n        # 
1.\u6570\u636e\u53bb\u91cd\n        data_n.drop_duplicates(subset=[&quot;trade_date&quot;], keep=&#039;first&#039;, inplace=True)\n        data_n.to_csv(path, header=True, encoding=&#039;utf-8&#039;, index=False)\n        return data_n.head(1)[&quot;trade_date&quot;].values[0]\n\n    def recordLastDate(self,symbol,data):\n        &quot;&quot;&quot;\n        \u8bb0\u5f55\u6700\u540edailyinfo\u7684\u6700\u540e\u66f4\u65b0\u65e5\u671f\n        :param symbol: stock \u4ee3\u7801\n        :param data: \u6700\u540e\u66f4\u65b0\u65e5\u671f\n        :return:  none\n        &quot;&quot;&quot;\n        stock_sql.updateTradeDate(symbol, data)\n\n    def updateDailyList(self, stockList):\n        &quot;&quot;&quot;\n        \u66f4\u65b0stocklist\u4e2d\u7684stock \u65e5\u4fe1\u606f\n        :param stockList:\u8f93\u5165\u7684\u80a1\u7968\u5217\u8868\u5305\u542b symbol name tscode  lastdate\n        :return:\n        &quot;&quot;&quot;\n        today = datetime.date.today().strftime(&quot;%Y%m%d&quot;)\n        for item in stockList:\n            # time.sleep(.3)\n            print(item)\n            data = self.getDailyInfo(item[0], item[2], item[3], &quot;20200312&quot;)\n            if data is not None:\n                last_d = self.dailyInfoConcat(item[0], data)\n                self.recordLastDate(item[0], last_d)\n\n    def updateDailyListMutliple(self,stockList):\n        &quot;&quot;&quot;\n        \u66f4\u65b0stocklist\u4e2d\u7684stock \u65e5\u4fe1\u606f (\u591a\u7ebf\u7a0b)\n        :param stockList:\u8f93\u5165\u7684\u80a1\u7968\u5217\u8868\u5305\u542b symbol name tscode  lastdate\n        :return:\n        &quot;&quot;&quot;\n        #\u5bf9\u5217\u8868\u8fdb\u884c\u62c6\u5206\n        stock_split = stock_sql.arr_split( stockList ,500)\n        executor = ProcessPoolExecutor(max_workers=8)\n        all_task = [executor.submit( self.updateDailyList, list_s) for list_s in stock_split]\n        for future in as_completed(all_task) :\n            res = 
future.result()\n\nclass stock_company(object):\n    &quot;&quot;&quot;\n    \u80a1\u7968\u76f8\u5173\u884c\u4e1a\u4fe1\u606f\n    &quot;&quot;&quot;\n\n    def __init__(self):\n        self.file_path = &quot;data\/stock_company.csv&quot;\n        self.company_info = {\n            &quot;ts_code&quot;: &quot;\u80a1\u7968\u4ee3\u7801&quot;,\n            &quot;province&quot;: &quot;\u6240\u5728\u7701\u4efd&quot;,\n            &quot;city&quot;: &quot;\u6240\u5728\u57ce\u5e02&quot;,\n            &quot;main_business&quot;: &quot;\u4e3b\u8425\u4e1a\u52a1&quot;,\n            &quot;business_scope&quot;: &quot;\u7ecf\u8425\u8303\u56f4&quot;\n        }\n\n    def getCompanyInfo(self):\n        df:DataFrame =pro.stock_company(fields =  list( self.company_info.keys()))\n        df.to_csv(self.file_path,header=True, encoding=&#039;utf-8&#039;, index=False)\n        print(df)\n\n    def getCompanyData(self):\n        drops = [&quot;ts_code&quot;]\n        columns =[x for x in self.company_info.keys() if x not in drops ]\n        df :DataFrame = pd.read_csv(self.file_path,index_col=False,na_values=&quot;&quot;)\n        df.sort_values(by=[&quot;ts_code&quot;], ascending=True, ignore_index=True, inplace=True)\n        df = df[columns]\n        return df\n\n    def removePunctuation(self,text):\n        text = re.sub(r&#039;[{}]+&#039;.format(&#039;!,;:?()[]\\n\uff0c\u3002\uff08\uff09\uff1a\uff1b\u3001\\&#039;&#039;), &#039; &#039;, text)\n        return text.strip().lower()\n\n    def word_cut_test(self):\n        sentences = [&quot;\u5438\u6536\u516c\u5171\u5b58\u6b3e&quot;,\n                     &quot;2015\u5e74\u6211\u6bd5\u4e1a\u4e8e\u897f\u5b89\u79d1\u6280\u5927\u5b66&quot;,\n                     &quot;2015\u5e74\u6211\u6bd5\u4e1a\u4e8e\u897f\u5b89\u7535\u5b50\u79d1\u6280\u5927\u5b66&quot;,\n                     &quot;2015\u5e74\u6211\u6bd5\u4e1a\u4e8e\u897f\u5b89\u5efa\u7b51\u79d1\u6280\u5927\u5b66&quot;,\n                     
&quot;2015\u5e74\u6211\u6bd5\u4e1a\u4e8e\u897f\u5b89\u4ea4\u901a\u5927\u5b66&quot;,\n                     &quot;2015\u5e74\u6211\u6bd5\u4e1a\u4e8e\u5317\u4eac\u5927\u5b66&quot;]\n\n        for sentence in sentences:\n            # \u5168\u6a21\u5f0f\n            words = jieba.cut(sentence, cut_all=True)\n            print(&quot;\u5168\u6a21\u5f0f:  %s&quot; % &quot; &quot;.join(words))\n\n            words = jieba.cut(sentence, use_paddle=True)\n            print(&quot;\u65b0\u8bcd\u6a21\u5f0f:  %s&quot; % &quot; &quot;.join(words))\n            # \u9ed8\u8ba4\u7cbe\u786e\u6a21\u5f0f\n            words = jieba.cut(sentence)\n            print(&quot;\u7cbe\u786e\u6a21\u5f0f:  %s&quot; % &quot; &quot;.join(words))\n\n            # \u641c\u7d22\u6a21\u5f0f\n            words = jieba.cut_for_search(sentence)\n            print(&quot;\u641c\u7d22\u6a21\u5f0f:  %s&quot; % &quot; &quot;.join(words))\n\n    def word_cut(self):\n        jieba.enable_paddle()\n        df =self.getCompanyData()\n        df =df[[&quot;business_scope&quot;]]\n        arr = np.array(df)\n        stopwords = self.stopwordslist(&quot;nlp\/stop_words.txt&quot;)\n        with open(&quot;nlp\/cut_words.txt&quot;,&quot;w+&quot;, encoding=&#039;utf-8&#039;) as fw:\n            for i in range(len(arr)):\n                list=self.removePunctuation(arr[i][0])\n                item = jieba.cut(list,use_paddle=True) # \u4f7f\u7528paddle\u6a21\u5f0f\n                santi_words =[x for x in item if len(x) &gt; 1 and x not in stopwords]\n                fw.writelines(santi_words)\n                fw.write(&quot;\\r\\n&quot;)\n                print(santi_words)\n\n    def stopwordslist(self,filepath):\n        stopwords = [line.strip() for line in open(filepath, &#039;r&#039;, encoding=&#039;utf-8&#039;).readlines()]\n        return stopwords\n\n    def word2vec(self):\n        sentences = word2vec.PathLineSentences(&quot;nlp\/cut_words.txt&quot;)\n        model = Word2Vec(sentences, size=20, window=5, 
min_count=1, workers=4)\n        model.save(&quot;nlp\/word2vec.model&quot;)\n        model = Word2Vec.load(&quot;nlp\/word2vec.model&quot;)\n        # a= model.train([[&quot;\u5438\u6536\u516c\u4f17\u5b58\u6b3e&quot;, &quot;\u5438\u6536\u516c\u4f17\u5b58\u6b3e&quot;]], total_examples=1, epochs=1)\n        vector = model.wv[&#039;\u65b0\u6750\u6599&#039;]\n        a=model.similar_by_vector(vector)\n        print(a)\n        print(vector)\n\n    def doc2vec(self):\n        sentences = word2vec.PathLineSentences(&quot;nlp\/cut_words.txt&quot;)\n        documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(sentences)]\n        # for  i ,doc in  enumerate(sentences):\n        #     ddd = TaggedDocument(doc,[i])\n        #     print(ddd)\n        model = Doc2Vec(documents, vector_size=20, window=2, min_count=1, workers=4)\n        model.save(&quot;nlp\/doc2vec.model&quot;)\n        model = Doc2Vec.load(&quot;nlp\/doc2vec.model&quot;)\n        vector = model.infer_vector([&quot;\u7535\u5668\u5f00\u5173\u96f6\u90e8\u4ef6\u53ca\u9644\u4ef6\u5236\u9020&quot;])\n        model.similar_by_vector(vector)\n        pass\n\nif __name__ == &quot;__main__&quot;:\n    # stock_record()\n    # stock_data()\n    # company= stock_company()\n    # company.doc2vec()\n    basic  =stock_basic()\n    basic.getExVec()\n<\/code><\/pre>\n<h1>stock_sql.py<\/h1>\n<pre><code class=\"language-python\">from datetime import datetime\nimport pandas as pd\nimport numpy as np\nimport sqlite3\nimport os\nimport pickle\n\nd = {&quot;ts_code&quot;: &quot;TS\u4ee3\u7801&quot;,\n     &quot;symbol&quot;: &quot;\u80a1\u7968\u4ee3\u7801&quot;,\n     &quot;name&quot;: &quot;\u80a1\u7968\u540d\u79f0&quot;,\n     &quot;area&quot;: &quot;\u6240\u5728\u5730\u57df&quot;,\n     &quot;industry&quot;: &quot;\u6240\u5c5e\u884c\u4e1a&quot;,\n     &quot;fullname&quot;: &quot;\u80a1\u7968\u5168\u79f0&quot;,\n     &quot;enname&quot;: &quot;\u82f1\u6587\u5168\u79f0&quot;,\n     &quot;market&quot;: 
&quot;\u5e02\u573a\u7c7b\u578b(\u4e3b\u677f\/\u4e2d\u5c0f\u677f\/\u521b\u4e1a\u677f\/\u79d1\u521b\u677f)&quot;,\n     &quot;exchange&quot;: &quot;\u4ea4\u6613\u6240\u4ee3\u7801&quot;,\n     &quot;curr_type&quot;: &quot;\u4ea4\u6613\u8d27\u5e01&quot;,\n     &quot;list_status&quot;: &quot;\u4e0a\u5e02\u72b6\u6001:L\u4e0a\u5e02 D\u9000\u5e02 P\u6682\u505c\u4e0a\u5e02&quot;,\n     &quot;list_date&quot;: &quot;\u4e0a\u5e02\u65e5\u671f&quot;,\n     &quot;delist_date&quot;: &quot;\u9000\u5e02\u65e5\u671f&quot;,\n     &quot;is_hs&quot;: &quot;\u662f\u5426\u6caa\u6df1\u6e2f\u901a\u6807\u7684\uff0cN\u5426 H\u6caa\u80a1\u901a S\u6df1\u80a1\u901a&quot;,\n     }\n\nDataBase = &quot;data\\my_stock.db&quot;\nRECORDTABLE = &#039;&#039;\n\nclass my_sql:\n\n    def __init__(self, baseName):\n        self.conn = sqlite3.connect(baseName)\n\n    def __enter__(self):\n        # print(&#039;__enter__() is call!&#039;)\n        return self.conn.cursor()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.conn.commit()\n        self.conn.close()\n        # print(&#039;__exit__() is call!&#039;)\n        # print(f&#039;type:{exc_type} : &#039; + f&#039;value:{exc_value}&#039;)\n        # print(f&#039;value:{exc_value}&#039;)\n        # print(f&#039;trace:{traceback}&#039;)\n        return True  # \u5f02\u5e38\u4e0d\u629b\u51fa\n\ndef CreateStockTabel():\n    Sql_str = &#039;CREATE TABLE stock_info (&#039;\n    for item in list(d.keys()):\n        Sql_str += item + &quot; CHAR(50), &quot;\n    Sql_str = Sql_str[:-2]\n    Sql_str += &quot;) ;&quot;\n\n    print(Sql_str)\n    with my_sql(DataBase) as c:\n        c.execute(Sql_str)\n\ndef createRecordTable():\n    Sql_str = &#039;CREATE TABLE  IF NOT EXISTS stock_record (&#039; \\\n              &#039;symbol TEXT PRIMARY KEY, name TEXT, time TEXT, real_data BLOB, predict_data BLOB )&#039;\n    with my_sql(DataBase) as c:\n        c.execute(Sql_str)\n\ndef createDateTable():\n    &quot;&quot;&quot;\n    
\u521b\u5efastock_date \u6570\u636e\u8868 \u7528\u6765\u8bb0\u5f55daily  train\u7b49\u65e5\u671f\u4fe1\u606f\n    :return:\n    &quot;&quot;&quot;\n    Sql_str = &#039;CREATE TABLE  IF NOT EXISTS stock_date (&#039; \\\n              &#039;symbol TEXT PRIMARY KEY, name TEXT, ts_code TEXT, &#039; \\\n              &#039;list_date TEXT, trade_date TEXT, train_date TEXT)&#039;\n\n    with my_sql(DataBase) as c:\n        c.execute(Sql_str)\n\ndef saveStockData():\n    &quot;&quot;&quot;\n    \u5c06stock\u7684\u57fa\u672c\u4fe1\u606f\u5b58\u653e\u518d\u6570\u636e\u5e93\n    :return:\n    &quot;&quot;&quot;\n    df = pd.read_csv(&quot;data\/stock_basic.csv&quot;, dtype=object)\n    df.columns = list(d.keys())\n    conn = sqlite3.connect(DataBase)\n    df.to_sql(name=&quot;stock_info&quot;, con=conn, index=False, if_exists=&quot;replace&quot;)\n    conn.close()\n\ndef initDateTable():\n    &quot;&quot;&quot;\n    \u751f\u6210 \u8868\u7528\u6765\u8bb0\u5f55 stock\u4fe1\u606f\u66f4\u65b0\u5230\u54ea\u4e00\u5929\uff0c \u6a21\u578b\u5bf9\u8bad\u7ec3\u5230\u54ea\u4e00\u5929\n    :return:\n    &quot;&quot;&quot;\n    sqlStr = &quot;SELECT symbol, name, ts_code,list_date FROM stock_info &quot;\n    conn = sqlite3.connect(DataBase)\n    df = pd.read_sql(sqlStr, con=conn)\n    conn.close()\n    with my_sql(DataBase) as c:\n        for index, row in df.iterrows():\n            sqlStr = &quot;REPLACE INTO stock_date (symbol, name, ts_code, list_date, info_date, train_date) &quot; \\\n                     &quot;VALUES (?,?,?,?,?,?)&quot;\n            c.execute(sqlStr,(row[&quot;symbol&quot;],row[&quot;name&quot;],row[&quot;ts_code&quot;],\n                              row[&quot;list_date&quot;],row[&quot;list_date&quot;],row[&quot;list_date&quot;]))\n\ndef saveStockPath():\n    &quot;&quot;&quot;\n    \u5c06stock\u7684\u4f4d\u7f6e\u4fe1\u606f\u5b58\u5165\u6570\u636e\u5e93\n    :return:\n    &quot;&quot;&quot;\n    sqlStr = &quot;SElECT name, symbol, ts_code FROM stock_info &quot;\n   
 conn = sqlite3.connect(DataBase)\n    df = pd.read_sql(sqlStr, con=conn)\n\n    paths = []\n    for item in df.values:\n        path_o = &quot;&quot;.join([&quot;data\/info\/&quot;, item[0].replace(&quot;*&quot;, &quot;&quot;), &quot;-&quot;, item[2], &quot;.csv&quot;])\n        path_w = &quot;&quot;.join([&quot;data\/info\/&quot;, item[1], &quot;.csv&quot;])\n        paths.append(path_w)\n        # os.rename(path_o,path_w)\n    df.drop(columns=[&quot;ts_code&quot;], inplace=True)\n    df.insert(loc=2, column=&quot;path&quot;, value=paths)\n    df.to_sql(name=&quot;stock_path&quot;, con=conn, index=False, if_exists=&#039;replace&#039;)\n    conn.close()\n    print(df)\n\ndef initTrainDate():\n    &quot;&quot;&quot;\n    \u521d\u59cb\u5316 \u8bad\u7ec3\u65e5\u671f\n    :param symbol:\n    :param date:\n    :return:\n    &quot;&quot;&quot;\n    sqlStr = &quot;UPDATE stock_date SET train_date = list_date&quot;\n    with my_sql(DataBase) as c:\n        c.execute(sqlStr)\n    print(&quot;initTrainDate&quot;)\n\ndef updateTradeDate(symbol,date):\n    &quot;&quot;&quot;\n    \u66f4\u65b0stock\u7684 \u65e5\u671f\u4fe1\u606f \u8bb0\u5f55\u65e5\u7ebf\u66f4\u65b0\u5230\u54ea\u90a3\u4e00\u5929\n    :param symbol: stock \u4ee3\u7801\n    :param date:   \u4ea4\u6613\u65e5\u671f\n    :return:\n    &quot;&quot;&quot;\n    sqlStr = &quot;UPDATE stock_date SET trade_date =&#039;{}&#039; WHERE symbol=&#039;{}&#039;&quot;.format(date, symbol)\n    with my_sql(DataBase) as c:\n        c.execute(sqlStr)\n\ndef updateTrianList(stockList):\n    &quot;&quot;&quot;\n    \u66f4\u65b0\u80a1\u7968\u7684 \u8bad\u7ec3\u65e5\u671f\n    :param dateList:\u8f93\u5165\u5f85\u8ddf\u65b0\u7684stock \u5217\u8868  \u5305\u542b symbol, name\n    :return:\n    &quot;&quot;&quot;\n    with my_sql(DataBase) as c:\n        for item in stockList:  # symbol date tscode\n            sqlStr = &quot;UPDATE stock_date SET train_date = trade_date WHERE symbol=&#039;{}&#039;&quot;.format(item[0])\n            
c.execute(sqlStr)\n\ndef getPathList(symbolList):\n    &quot;&quot;&quot;\n    \u83b7\u53d6stock \u4fe1\u606f\u7684\u6587\u4ef6\u4f4d\u7f6e\u5217\u8868\n    :param symbolList: \u8f93\u5165\u5217\u8868\u5305\u542b\u53ef symbol\u548c name \u4fe1\u606f\n    :return: pathList  \u8f93\u51fa\u5217\u8868 \u5305\u542b path \u548c symbol\n    &quot;&quot;&quot;\n    pathList=[]\n    with my_sql(DataBase) as c:\n        for item in symbolList: #symbol :(symbol,name)\n            sqlStr = &quot;SELECT path FROM stock_path WHERE symbol =&#039;{}&#039;&quot;.format(item[0])\n            cursor = c.execute(sqlStr)\n            path = cursor.fetchone()[0]\n            pathList.append((path,item[0]))\n    return pathList #(path,symbol)\n\ndef getFilePath(symbol):\n    sqlStr = &quot;SELECT path FROM stock_path WHERE symbol = \\&quot;&quot; + symbol + &quot;\\&quot;&quot;\n    with my_sql(DataBase) as c:\n        cursor = c.execute(sqlStr)\n        path = cursor.fetchone()[0]\n        return path\n\ndef getLastTradeList(pathList):\n    &quot;&quot;&quot;\n    \u83b7\u53d6stock \u6700\u540e\u66f4\u65b0\u4fe1\u606f\n    :param pathList:\u8f93\u51fa\u5217\u8868 \u5305\u542b path \u548c symbol\n    :return:dateList \u8f93\u51fa\u5217\u8868 \u5305\u542b symbol date tscode)\n    &quot;&quot;&quot;\n    dateList =[] # (symbol date tscode)\n    for item in pathList: #item : (path, symbol)\n        df = pd.read_csv(item[0], dtype=object)\n        df.sort_values(by=[&quot;trade_date&quot;], inplace=True, ignore_index=True)\n        date = df.tail(1)[[&quot;trade_date&quot;, &quot;ts_code&quot;]].values[0]\n        dateList.append((item[1],date[0],date[1]))\n    return  dateList\n\ndef getLastDate(symbol):\n    &quot;&quot;&quot;\n    \u83b7\u53d6stock\u7684\u8bad\u7ec3\u6570\u636e\u7684\u6700\u540e\u65e5\u671f\uff1a\n    :param symbol: stock\u7684symbol\u4fe1\u606f\n    :return:  : trade_date,train_date\n    &quot;&quot;&quot;\n    columns = [&quot;symbol&quot;, &quot;trade_date&quot;, 
&quot;train_date&quot;]\n    sqlStr = &quot;SELECT {} From stock_date WHERE symbol= &#039;{}&#039;&quot;.format(columnsToSql(columns),symbol)\n\n    with my_sql(DataBase) as c:\n        cursor= c.execute(sqlStr)\n        data = cursor.fetchone()\n        return  data[1],data[2]\n\ndef getLastTradeDate(symbol):\n    &quot;&quot;&quot;\n    \u83b7\u53d6stock\u7684\u8bad\u7ec3\u6570\u636e\u7684\u6700\u540e\u65e5\u671f\uff1a\n    :param symbol: stock\u7684symbol\u4fe1\u606f\n    :return:  nparray: [date,tscode]\n    &quot;&quot;&quot;\n    path =getFilePath(symbol)\n    df =pd.read_csv(path,dtype=object)\n    df.sort_values(by=[&quot;trade_date&quot;],inplace=True,ignore_index=True)\n    date = df.tail(1)[[&quot;trade_date&quot;,&quot;ts_code&quot;]]\n    return date.values[0]\n\ndef getTradeDateList():\n    sqlStr= &quot;SELECT symbol, name, ts_code,trade_date FROM stock_date&quot;\n    conn =sqlite3.connect(DataBase)\n    df= pd.read_sql(sqlStr,con=conn)\n    conn.close();\n    return df\n\ndef getStockFrame(headers = [&quot;symbol&quot;,&quot;name&quot;]):\n    sqlStr = &quot;SELECT {} FROM stock_info&quot;.format(columnsToSql(headers))\n    conn = sqlite3.connect(DataBase)\n    df = pd.read_sql(sqlStr, con=conn, columns=headers)\n    conn.close();\n    return df\n\ndef getStockInfo(item):\n    if not item:\n        return pd.DataFrame()\n\n    columns = [&quot;symbol&quot;, &quot;area&quot;, &quot;industry&quot;, &quot;market&quot;, &quot;list_date&quot;, &quot;is_hs&quot;]\n    sqlStr = &quot;SELECT {} From stock_info WHERE symbol= &#039;{}&#039;&quot;.format(columnsToSql(columns),item[&quot;symbol&quot;])\n    conn = sqlite3.connect(DataBase)\n    df = pd.read_sql(sqlStr, con=conn)\n    conn.close()\n    df.fillna(&quot;&quot;, inplace=True)\n\n    if (df.values.shape[0] != 1):\n        return pd.DataFrame()\n    for index, key, in enumerate(columns):\n        name, value = parseStockInfo(key, df.values[0][index])\n        print(name, value)\n        
df.values[0][index] = name + &quot;: &quot; + value\n    return df\n\ndef getDailyFrame(symbol:str)-&gt;pd.DataFrame:\n    path = getFilePath(symbol)\n    df = getStockData(path)\n    # 2.\u5c06nan\u6570\u636e \u66ff\u6362\u4e3a0\n    df.fillna(0, inplace=True)\n    # 3.\u6309\u7167\u65e5\u671f\u6392\u5e8f ,\u5e76\u5ffd\u7565\u7d22\u5f15\n    df.sort_values(by=[&quot;trade_date&quot;], ascending=True, ignore_index=True, inplace=True)\n    return df\n\ndef getStockData(path):\n    data = pd.read_csv(path, index_col=False, na_values=0)\n    # 1.\u6570\u636e\u53bb\u91cd\n    data.drop_duplicates(subset=[&quot;trade_date&quot;], keep=&#039;first&#039;, inplace=True)\n    return data\n\ndef getRecordData(symbol):\n    if not symbol :\n        return None\n    now_date = datetime.now().strftime(&quot;%Y%m%d&quot;)\n    sqlStr = &quot;SELECT real_data, predict_data, symbol FROM stock_record&quot; \\\n             &quot; WHERE symbol=&#039;{}&#039; AND time=&#039;{}&#039;&quot;.format(symbol,now_date)\n    with my_sql(DataBase) as c:\n        cursor = c.execute(sqlStr)\n        data = cursor.fetchone()\n        if data :\n            real, predict = pickle.loads(data[0]),pickle.loads(data[1])\n            return (real,predict)\n\ndef SaveRecordData(stockInfo, realData, predictData):\n    reald, pred = pickle.dumps(realData), pickle.dumps(predictData)\n    now_date = datetime.now().strftime(&quot;%Y%m%d&quot;)\n\n    sqlStr = &quot;REPLACE INTO stock_record (symbol,name,time,real_data, predict_data) VALUES (?,?,?,?,?) 
&quot;\n    with my_sql(DataBase) as c:\n        c.execute(sqlStr, (stockInfo[&quot;symbol&quot;], &quot;&quot;, now_date, reald, pred))\n    #print(now_date)\n    pass\n\ndef ClearRecordData():\n    sqlStr =&quot;DROP TABLE stock_record&quot;\n    with my_sql(DataBase) as c:\n        c.execute(sqlStr)\n\ndef parseStockInfo(key, value):\n    name = d[key]\n    if key == &#039;market&#039;:\n        name = &quot;\u5e02\u573a\u7c7b\u578b&quot;\n    elif key == &quot;curr_type&quot;:\n        if value == &quot;CNY&quot;:\n            value = &quot;RMB&quot;\n    elif key == &quot;list_status&quot;:\n        name = &quot;\u4e0a\u5e02\u72b6\u6001&quot;\n        if value == &quot;L&quot;:\n            value = &quot;\u4e0a\u5e02&quot;\n        elif value == &quot;D&quot;:\n            value = &quot;\u9000\u5e02&quot;\n        elif value == &quot;P&quot;:\n            value = &quot;\u6682\u505c\u4e0a\u5e02&quot;\n    elif key == &quot;is_hs&quot;:\n        name = &quot;\u6caa\u6df1\u6e2f\u901a&quot;\n        if value == &quot;S&quot;:\n            value = &quot;\u6df1\u6e2f\u901a&quot;\n        elif value == &quot;H&quot;:\n            value = &quot;\u6caa\u6e2f\u901a&quot;\n        elif value == &quot;N&quot;:\n            value = &quot;\u5426&quot;\n    return name, value\n\ndef columnsToSql(columns):\n    &quot;&quot;&quot;\n    \u628a\u5217\u8868\u8f6c\u6362\u6210\u5bf9\u5e94\u7684sql \u5217\u5b57\u7b26\u4e32\n    :param columns:\n    :return:\n    &quot;&quot;&quot;\n    columnStr = &quot;&quot;\n    for x in columns:\n        columnStr += x + &quot;,&quot;\n    return  columnStr[:-1]\n\ndef arr_split(arr,size):\n    s = [arr[i:i+size]  for i in range(0,int(len(arr))+1,size)]\n    return s\n\ndef DataBaseInit():\n    # CreateStockTabel()\n    saveStockData()\n\ndef sqlTest():\n    con = sqlite3.connect(&quot;data\\my_stock.db&quot;)\n    sql = &#039;select * from user_information LIMIT 3&#039;\n    df = pd.read_sql(sql, con)\n    print(df)\n\nif __name__ == 
&#039;__main__&#039;:\n    # DataBaseInit()\n    # getStockList()\n    # saveStockPath()\n    initTrainDate()\n    #getFilePath(&quot;000001&quot;)\n<\/code><\/pre>\n<h1>prepare.py<\/h1>\n<pre><code class=\"language-python\"># -*-coding:utf-8 -*-\nimport sys\n#print(sys.executable)\nimport numpy as np\nimport pandas as pd\nfrom sklearn.preprocessing import MinMaxScaler\nimport  stock_sql\nfrom abc import ABCMeta, abstractmethod\nfrom multiprocessing import cpu_count\nfrom concurrent.futures import ProcessPoolExecutor,as_completed\nimport time\nimport get_info\nimport  datetime\n\n#pd.show_versions()\ndaily = {\n    &quot;ts_code&quot;: &quot;\u80a1\u7968\u4ee3\u7801&quot;,\n    &quot;trade_date&quot;: &quot;\u4ea4\u6613\u65e5\u671f&quot;,\n    &quot;open&quot;: &quot;\u5f00\u76d8\u4ef7&quot;,\n    &quot;high&quot;: &quot;\u6700\u9ad8\u4ef7&quot;,\n    &quot;low&quot;: &quot;\u6700\u4f4e\u4ef7&quot;,\n    &quot;close&quot;: &quot;\u6536\u76d8\u4ef7&quot;,\n    &quot;pre_close&quot;: &quot;\u6628\u6536\u4ef7&quot;,\n    &quot;change&quot;: &quot;\u6da8\u8dcc\u989d&quot;,\n    &quot;pct_chg&quot;: &quot;\u6da8\u8dcc\u5e45(\u672a\u590d\u6743)&quot;,\n    &quot;vol&quot;: &quot;\u6210\u4ea4\u91cf(\u624b)&quot;,\n    &quot;amount&quot;: &quot;\u6210\u4ea4\u989d(\u5343\u5143)&quot;,\n}\n\nclass stockInfo:\n    def __init__(self,symbol:str =&quot;000001&quot;,name:str=&quot;\u5e73\u5b89\u94f6\u884c&quot;):\n        self.symbol=symbol\n        self.name=name\n\nSEQUENCE_LEN=20\nRREDICT_LEN=30\n\ndef parseUntrainedData(symbol,df):\n    &quot;&quot;&quot;\n    \u89e3\u6790\u672a\u8bad\u7ec3\u7528\u7684\u6570\u636e\n    :return:\n    &quot;&quot;&quot;\n    index =0\n    dateInfo =stock_sql.getLastDate(symbol)\n    lastDate = int(dateInfo[1])\n    for i,row in df.iterrows():\n        if row[&quot;trade_date&quot;] &gt;= lastDate:\n            index  = i\n            #df= df.iloc[:index+ SEQUENCE_LEN]\n            break\n    if index &lt; SEQUENCE_LEN -1 : 
#\u6570\u636e\u4ece\u672a\u8bad\u7ec3\u8fc7 \u5e94\u8be5\u4ece0\u5f00\u59cb\u8bad\u7ec3\n        index =0\n    elif SEQUENCE_LEN -1 &lt;= index: # \u6570\u636e\u8bad\u7ec3\u8fc7\uff1a\n        index = index+1  -(SEQUENCE_LEN -1)\n    return index\n\ndef getTrianColnum(stockinfo:stockInfo,length=RREDICT_LEN):\n    symbol = stockinfo.symbol\n    df = stock_sql.getDailyFrame(symbol)\n    # \u5bf9\u6570\u636e\u8fdb\u884c\u8bad\u4e86\u6d4b\u8bd5\u7684\u5206\u79bb\u5904\u7406\n    x, y = prepareLogistics().dataProcess(df)\n    col =[&quot;close&quot;]\n    close = np.array(x[col])[-length:]\n    return  np.array( close)\n\ndef columnSplit(verify=False):\n    &quot;&quot;&quot;\n    \u5c06\u6570\u636e\u6309\u7167\u5217\u7684\u65b9\u5f0f\u62c6\u5206\u6210 x\u548cy\n    verify :\u9a8c\u8bc1y \u7684\u65e5\u671f\u662f\u5426\u6b63\u786e \u5f00\u542f\u65f6 \u5c06\u5e26\u4e0ay\u5bf9\u5e94\u7684\u65e5\u671f\n    :return:\n    &quot;&quot;&quot;\n    all = list(daily.keys());\n    column_y=[&#039;ts_code&#039;]\n    column_x = [x for x in all if x not in column_y]\n    if verify:# \u5e26\u4e0a\u65e5\u671f \u9a8c\u8bc1 \u65e5\u671f\u662f\u5426\u5bf9\u5e94\n        column_y = [&quot;close&quot;,&#039;trade_date&#039;]\n    else:\n        column_y = [&quot;close&quot;]\n    return  column_x ,column_y\n\ndef getPredictTrainnData( stockInfo,length =RREDICT_LEN):\n    #\u83b7\u53d6stock \u539f\u59cb\u4fe1\u606f \uff0c\u53bb\u91cd \u3001\u9884\u6d4b\u6570\u636e\u5bf9\u5176\n    symbol = stockInfo[&quot;symbol&quot;]\n    df = stock_sql.getDailyFrame(symbol)\n    column_x, column_y = columnSplit()\n    x = np.array(df[column_x])\n    x= np.around(x,decimals=2)\n    return  x[-length:]\n\ndef dataSplit(npArr, split=0.9):\n    &quot;&quot;&quot;\n    :param npArr:\n    :param split:\n    :return:\n    &quot;&quot;&quot;\n    #\u5bf9\u6570\u636e\u8fdb\u884c\u62c6\u5206 \u83b7\u5f97\u9a8c\u8bc1\u6570\u636e\u96c6\n    split_boundary = int(npArr.shape[0] * split)\n    train_x = npArr[: 
split_boundary]\n    test_x = npArr[split_boundary:]\n    return train_x, test_x\n\nclass prepareBase:\n    def __init__(self):\n        print(str( type(self) )+&quot; __build __&quot;)\n        self.verify = False\n        self.only_untrain = False\n        self.batch_size = 200\n        self.SEQUENCE_LEN = 50\n        self.RREDICT_LEN=30\n        self.tuple_x=1\n        pass\n\n    def getInputShape(self):\n        &quot;&quot;&quot;\n        \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u7684shape\n        :return:  shape\n        &quot;&quot;&quot;\n        column_x, _ = columnSplit()\n        return (self.SEQUENCE_LEN, len(column_x))\n\n    def dataProcess(self,df:pd.DataFrame)-&gt;(np.array,np.array):\n        &quot;&quot;&quot;\n        :param df:\n        :return:\n        &quot;&quot;&quot;\n        # 4.\u5bf9\u6570\u636e\u6309\u7167\u8868\u683c\u62c6\u5206\n        column_x, column_y = columnSplit(self.verify)\n        x, y = df[column_x], df[column_y]\n        # 5\u62fc\u51d1\u6570\u636e x\u7684\u6700\u540e\u4e00\u884c\u6ca1\u6709\u9884\u6d4b\u503c y\u7684\u7b2c\u4e00\u884c\u6ca1\u6709 \u8bad\u7ec3\u503c\n        x, y = x.drop(len(x) - 1, axis=0), y.drop(0, axis=0)\n        y.reset_index(drop=True, inplace=True)\n        return x, y\n\n    def dataSequence(self,x:pd.DataFrame,nor:bool=True)-&gt;np.array:\n        &quot;&quot;&quot;\n        \u5c06dataFrame\u6570\u636e \u7ec4\u6210\u8bad\u7ec3\u5e8f\u5217\n        :param df: \u539f\u59cbdataFrame\u6570\u636e\n        :param nor: \u662f\u5426\u5bf9\u6570\u636e\u6807\u51c6\u5316\n        :param len: \u5e8f\u5217\u957f\u5ea6\n        :return: \u5e8f\u5217\u5316\u6570\u636e\u4ee5\u53ca \u6807\u51c6\u5316\u7684Scaler\n        &quot;&quot;&quot;\n        scaler = MinMaxScaler()\n        if nor :\n            data_all = np.array(x).astype(&quot;float64&quot;)\n            data_all = scaler.fit_transform(data_all)\n        else:\n            data_all = np.array(x)\n        data = []\n        for i in range(len(data_all) - 
self.SEQUENCE_LEN  + 1):\n            data.append(data_all[i: i +  self.SEQUENCE_LEN])\n        x = np.array(data).astype(&#039;float64&#039;)\n\n        return x, scaler\n\n    def dataSequence_y(self,y:pd.DataFrame,nor:bool=True)-&gt;np.array:\n        &quot;&quot;&quot;\n        \u5c06dataFrame\u6570\u636e \u7ec4\u6210\u8bad\u7ec3\u5e8f\u5217\n        :param df: \u539f\u59cbdataFrame\u6570\u636e\n        :param nor: \u662f\u5426\u5bf9\u6570\u636e\u6807\u51c6\u5316\n        :param len: \u5e8f\u5217\u957f\u5ea6\n        :return: \u5e8f\u5217\u5316\u6570\u636e\u4ee5\u53ca \u6807\u51c6\u5316\u7684Scaler\n        &quot;&quot;&quot;\n        scaler = MinMaxScaler()\n        if nor :\n            data_all = np.array(y).astype(&quot;float64&quot;)\n            data_all = scaler.fit_transform(data_all)\n        else:\n            data_all = np.array(y)\n        data = []\n        for i in range(len(data_all) - self.SEQUENCE_LEN + 1):\n            data.append(data_all[i + self.SEQUENCE_LEN - 1])\n        x = np.array(data)\n\n        return x, scaler\n\n    def test(self):\n        print(&quot;+++++++++++++++++++++++&quot;)\n        self.verify = True\n        self.only_untrain = True\n        df = stock_sql.getStockFrame()\n        df.sort_values(by=[&quot;symbol&quot;], inplace=True, ignore_index=True)\n        stock_list = np.array(df)\n        for info in stock_list:\n            data = self.getTrainData(stockInfo(symbol=info[0]))\n            data_x, data_y = data[0], data[1]\n            for i in range(len(data_x)):\n                print(data_x[i][self.SEQUENCE_LEN - 1])\n                print(data_y[i])\n\n    def getTrainData(self, stockinfo: stockInfo) -&gt; (np.array, np.array):\n        &quot;&quot;&quot;\n           \u83b7\u53d6\u5355\u53ea\u80a1\u7968\u7684\u8bad\u7ec3\u6570\u636e\n           :param stockInfo: stock \u4fe1\u606f\n           :param only_untrain: \u53ea\u4f7f\u7528\u672a\u8bad\u7ec3\u7684\u6570\u636e\n           :return: 
\u5982\u679c\u6570\u636e\u4e0d\u591f\u8fd4\u56deNone\n           &quot;&quot;&quot;\n        symbol = stockinfo.symbol\n        df = stock_sql.getDailyFrame(symbol)\n        if len(df) &lt;= self.SEQUENCE_LEN:\n            return None\n        # \u5bf9\u6570\u636e\u8fdb\u884c\u8bad\u4e86\u6d4b\u8bd5\u7684\u5206\u79bb\u5904\u7406\n        x, y = self.dataProcess(df)\n        train_x, scaler_x = self.dataSequence(x, nor= not self.verify)\n        train_y, scaler_y = self.dataSequence_y(y,nor =not self.verify)\n        if self.only_untrain:\n            index = parseUntrainedData(symbol, df)\n            if index &gt;= len(train_x):\n                # \u8d77\u59cb\u8bad\u7ec3\u6570\u636e\u8d85\u51fa\u662f\u957f\u5ea6,\u6ca1\u6709\u6570\u636e\n                return None\n            train_x, train_y = train_x[index:], train_y[index:]\n        return train_x, train_y,scaler_x,scaler_y\n\n    def getTestData(self, stockInfo:stockInfo):\n        data =self.getTrainData(stockInfo)\n        if data is not None and len(data[0])&gt;=0:\n            train_x, train_y =data[0],data[1]\n            return  train_x[-self.RREDICT_LEN:], train_y[-self.RREDICT_LEN:],data[2],data[3]\n        else:\n            return None\n\n    def dataGenerator(self, index: int = 0):\n        df = stock_sql.getStockFrame()\n        df.sort_values(by=[&quot;symbol&quot;], inplace=True, ignore_index=True)\n        # \u622a\u53d6index\u4ee5\u540e\u7684\u6570\u636e\n        stock_array = np.array(df)[index:]\n        batch_array = stock_sql.arr_split(stock_array,  self.batch_size)\n\n        for s_array in batch_array:\n            # \u591a\u8fdb\u7a0b\n            list_x, list_y =  self.dataConcat(s_array)\n            # list_x,list_y =dataConcat(s_array,only_untrain)\n            if len(list_x[0]) == 0 or len(list_y) == 0: continue\n            data_x = []\n            for i in range(self.tuple_x):\n                data_x.append( np.concatenate(list_x[i], axis=0))\n            data_y =  
np.concatenate(list_y, axis=0)\n            yield data_x, data_y, s_array\n            # yield (getTrainData({&quot;symbol&quot;:info[0]},only_untrain),info)\n\n    def dataConcat(self,stock_list):\n        &quot;&quot;&quot;\n           \u6279\u91cf\u5904\u7406stock list \u6570\u636e  \u5c06\u6bcf\u53ea\u5904\u7406\u7684stock\u6570\u636e\u653e\u5230 listx \u548c listy\u4e2d\n           :param stock_list:  \u5f85\u5904\u7406\u7684stock list \u5305\u542bsymbol name \u4fe1\u606f\n           :param only_untrain:  \u53ea\u5904\u7406\u672a\u4f7f\u7528\u7684\u6570\u636e\n           :return: list_x,list_x\n        &quot;&quot;&quot;\n        list_x, list_y = [], []\n        for i in range(self.tuple_x):\n            list_x.append([])\n\n        for info in stock_list:\n            data =self.getTrainData(stockInfo(symbol=info[0]))\n            print(&quot;stock info :&quot; + info + &quot;\\r\\n&quot;)\n            if data == None: continue\n            for i in  range(self.tuple_x):\n                list_x[i].append(data[0][i])\n            list_y.append(data[1])\n        return list_x, list_y\n\n    def dataConcatMultiple(self, stock_array, only_untrain=False):\n        list_x, list_y = [], []\n        for i in range(self.tuple_x):\n            list_x.append([])\n        split_array = np.array_split(stock_array, 4, axis=0)\n        executor = ProcessPoolExecutor(max_workers=4)\n\n        all_task = [executor.submit(self.dataConcat, list_a) for list_a in split_array]\n        for future in as_completed(all_task):\n            res = future.result()\n            for i in range(self.tuple_x):\n                list_x[i].extend(res[i][0])\n            list_y.extend(res[1])\n        return list_x, list_y\n\nclass logisticsAllScaler(prepareBase):\n    def __init__(self):\n        super(logisticsAllScaler,self).__init__()\n        self.SEQUENCE_LEN=50\n        self.list_x= None\n        time.sleep(10)\n        pass\n\n    def getDataWithIndex(self, stock_arr, 
only_untrain=False):\n        list_x, list_y, index_list = [], [], []\n        trian_map = {}\n        index = 0\n        for item in stock_arr:\n            print(item[1])\n            df = stock_sql.getDailyFrame(item[0])\n            if only_untrain:\n                index = parseUntrainedData(item[0], df)\n            # df.drop(columns=[&quot;ts_code&quot;],inplace=True)\n            x, y = self.dataProcess(df)\n            trian_map[item[0]] = index\n            list_x.append(x)\n            list_y.append(y)\n            index_list.append(len(x))\n        return list_x, list_y, index_list, trian_map\n\n    def getDataWithIndexMultiple(self,stock_arr, only_untrain=False):\n        list_x, list_y, index_list = [], [], []\n        trian_map = {}\n        a_split = np.array_split(stock_arr, 4, axis=0)\n        executor = ProcessPoolExecutor(max_workers=4)\n        all_tasks = [executor.submit(self.getDataWithIndex, stockS_arr, only_untrain) for stockS_arr in a_split]\n        for furture in as_completed(all_tasks):\n            res = furture.result();\n            list_xt, list_yt, index_listt, trian_mapt = res[0], res[1], res[2], res[3]\n            list_x.extend(list_xt)\n            list_y.extend(list_yt)\n            index_list.extend(index_listt)\n            trian_map.update(trian_mapt)\n        return list_x, list_y, index_list, trian_map\n\n    def allScaler(self):\n        if self.list_x is not None:\n            return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map\n\n        df = stock_sql.getStockFrame()\n        df.sort_values(by=[&quot;symbol&quot;], inplace=True, ignore_index=True)\n        arr = np.array(df)\n        data_new_x, data_new_y, list_index, self.train_map = self.getDataWithIndexMultiple(arr, self.only_untrain)\n        data_new_x = np.concatenate(data_new_x, axis=0)\n        data_new_y = np.concatenate(data_new_y, axis=0)\n        self.scaler_x, self.scaler_y = MinMaxScaler(), MinMaxScaler()\n        data_new_x, 
data_new_y = self.scaler_x.fit_transform(data_new_x), self.scaler_y.fit_transform(data_new_y)\n        self.list_x, self.list_y = {}, {}\n        list_symbols = list(self.train_map.keys())\n        for i in range(len(list_index)):\n            data_a, data_b = data_new_x[:list_index[i]], data_new_y[:list_index[i]]\n            data_new_x, data_new_y = data_new_x[list_index[i]:], data_new_y[list_index[i]:]\n            self.list_x[list_symbols[i]] = data_a\n            self.list_y[list_symbols[i]] = data_b\n\n        return self.list_x, self.list_y, self.scaler_x, self.scaler_y, self.train_map\n\n    def getTrainData(self, stockinfo: stockInfo) -&gt; (np.array, np.array):\n        self.allScaler()\n        symbol= stockinfo.symbol\n        x, y = self.list_x[symbol], self.list_y[symbol]\n        train_x, _ = self.dataSequence(x, nor=False)\n        train_y, _ = self.dataSequence_y(y, nor=False)\n        if self.only_untrain:\n            index = self. trian_map[symbol]\n            if index &gt;= len(train_x):  # \u8d77\u59cb\u8bad\u7ec3\u6570\u636e\u8d85\u51fa\u662f\u957f\u5ea6,\u6ca1\u6709\u6570\u636e\n                return None\n            train_x, train_y = train_x[index:], train_y[index:]\n        return train_x,train_y,self.scaler_x,self.scaler_y\n\nclass prepareLogistics(prepareBase):\n    def __init__(self):\n        super(prepareLogistics,self).__init__()\n        self.SEQUENCE_LEN=50\n        time.sleep(10)\n        pass\n\nclass prepareClassify(prepareBase):\n\n    def __init__(self):\n        super(prepareClassify,self).__init__()\n        time.sleep(10)\n        pass\n\n    def dataProcess(self,df:pd.DataFrame)-&gt; (np.array,np.array):\n        &quot;&quot;&quot;\n        :param df: \u5305\u542bstock\u4fe1\u606f\u7684 dataframe\n        :return:\n        &quot;&quot;&quot;\n        # 4.\u5bf9\u6570\u636e\u6309\u7167\u8868\u683c\u62c6\u5206\n        column_x, column_y = columnSplit(self.verify)\n        x, y_t = np.array( df[column_x]),np.array( 
df[column_y])\n        # 5\u62fc\u51d1\u6570\u636e x\u7684\u6700\u540e\u4e00\u884c\u6ca1\u6709\u9884\u6d4b\u503c y\u7684\u7b2c\u4e00\u884c\u6ca1\u6709 \u8bad\u7ec3\u503c\n        y_value= [ y_t[i+1][0]-y_t[i][0] for i in range(len(y_t)-1)]\n        y_value= np.int32( np.array( y_value)&gt; 0).reshape(-1, 1)\n        if self.verify:\n            y_date= [str(y_t[i][1]) + &quot;-&gt;&quot; + str(y_t[i + 1][1]) + &quot; :&quot; + str(y_t[i + 1][0]) + &quot;-&quot; + str(y_t[i][0]) for i in range(len(y_t)-1)]\n            y_date  = np.array(y_date).reshape(-1, 1)\n            y_value = np.concatenate(( y_date ,y_value),axis=1)\n        return x[: len(x)-1], y_value\n\n    def dataSequence_y(self,y:pd.DataFrame,nor :bool =False )-&gt; np.array:\n        data_all = np.array(y)\n        data = []\n        for i in range(len(data_all) -self.SEQUENCE_LEN  + 1):\n            data.append(data_all[i + self.SEQUENCE_LEN - 1])\n        return np.array(data) ,None\n\nclass prepareLogistics_Ex(prepareBase):\n    def __init__(self):\n        super(prepareLogistics_Ex,self).__init__()\n        self.SEQUENCE_LEN=50\n        self.tuple_x=3\n        stock_base = get_info.stock_basic()\n        self.area_vec,self.indu_vec ,self.area_shape,self.indu_shape= stock_base.getExVec()\n        time.sleep(10)\n        pass\n\n    def getInputShape(self):\n        daily_shape = super().getInputShape()\n        area_shape = (self.area_shape[1],)\n        indu_shape =  (self.indu_shape[1],)\n        return daily_shape,area_shape,indu_shape\n\n    def getTrainData(self, stockinfo: stockInfo) -&gt; (np.array, np.array):\n        data= super().getTrainData(stockinfo)\n        if data  is None:\n            return  None\n        train_x, train_y, self.scaler_x, self.scaler_y = data\n        sample_len = train_x.shape[0]\n        ex_area ,ex_indu=self.area_vec[stockinfo.symbol],self.indu_vec[stockinfo.symbol]\n        ex_area,ex_indu = np.tile(ex_area,(sample_len,1)), np.tile(ex_indu,(sample_len,1))\n\n 
       train_x =train_x,ex_area,ex_indu\n\n        return train_x,train_y,self.scaler_x,self.scaler_y\n\nif __name__ == &quot;__main__&quot;:\n    # getTrainData(1)\n    # allScaler()\n    # for data in dataGenerator(0,True):\n    #     print(&quot;data&quot;)\n    pre =prepareClassify()\n    pre.test()\n<\/code><\/pre>\n<h1>trainning.py<\/h1>\n<pre><code class=\"language-python\">import sys\n# print(sys.executable)\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport gc\n\nimport tensorflow as tf\nfrom tensorflow.keras.callbacks import TensorBoard\nfrom tensorflow.keras.callbacks import ModelCheckpoint\nimport datetime\nimport tensorflow.keras as keras\nimport prepare\n\nimport stock_sql\nfrom abc import ABCMeta, abstractmethod\nimport time\n\nimport pickle\n\n# gpus = tf.config.experimental.list_physical_devices(device_type=&#039;GPU&#039;)\n# cpus = tf.config.experimental.list_physical_devices(device_type=&#039;CPU&#039;)\n# print(gpus)\n# print(cpus)\n#\n# os.environ[&quot;CUDA_VISIBLE_DEVICES&quot;] = &quot;-1&quot;\n# tf.config.experimental.set_visible_devices(devices=gpus[0], device_type=&#039;GPU&#039;)\n# tf.config.experimental.set_visible_devices(devices=cpus[0], device_type=&#039;CPU&#039;)\n# tf.config.experimental.set_virtual_device_configuration(\n#     gpus[1],\n#     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1536)]\n# )\n\nweight_file = &quot;&quot;.join([&quot;model\/weights\/my_weight&quot;])\nweight_file_1 = &quot;&quot;.join([&quot;model\/weights\/my_weight_1&quot;])\nindex_file = &quot;index.txt&quot;\nlog_dir = &quot;&quot;.join(\n    [&quot;log\\\\model_train\\\\&quot;, datetime.datetime.now().strftime(&quot;%Y%m%d-%H%M%S&quot;)])\nmodel = None\n\ndef getTrainIndex():\n    with open(index_file, &quot;r&quot;) as f:\n        index = f.readline()\n    return int(index)\n\ndef saveTrainIndex(index):\n    with open(index_file, &quot;w&quot;) as f:\n        f.write(str(index))\n        f.flush()\n\ndef 
build_model(shape):\n    model = model_1(shape)\n    keras.utils.plot_model(model, &#039;picture\/multi_model.png&#039;, show_shapes=True)\n    return model\n\ndef model_2(shape):\n    model = keras.Sequential()\n    model.add(keras.Input(shape=shape))\n    model.add(keras.layers.LSTM(\n        units=50, activation=&#039;tanh&#039;, return_sequences=True))\n    model.add(keras.layers.LSTM(\n        units=50, activation=&#039;tanh&#039;, return_sequences=True))\n    model.add(keras.layers.Dense(units=10, activation=&quot;tanh&quot;))\n    model.add(keras.layers.Dense(units=1))\n    model.compile(optimizer=keras.optimizers.Adam(),\n                  loss=&quot;mse&quot;,\n                  metrics=[keras.metrics.mae])\n    model.summary()\n    return model\n\ndef testing(stockInfo):\n    &quot;&quot;&quot;\n\n    :param stockInfo:{symbol: XXX,name}\n    :return:\n    &quot;&quot;&quot;\n    return logisticsTrain().predict(stockInfo)\n\ndef showData(test_y, predict_y):\n    x = [i for i in range(len(predict_y))]\n    plt.plot(x, predict_y, &quot;b*--&quot;)\n    plt.plot(x, test_y, &quot;gv--&quot;)\n    # plt.plot(test_y)\n    plt.show()\n\ndef template():\n    # \u6784\u5efa\u4e00\u4e2a\u6839\u636e\u6587\u6863\u5185\u5bb9\u3001\u6807\u7b7e\u548c\u6807\u9898\uff0c\u9884\u6d4b\u6587\u6863\u4f18\u5148\u7ea7\u548c\u6267\u884c\u90e8\u95e8\u7684\u7f51\u7edc\n    # \u8d85\u53c2\n    num_words = 2000\n    num_tags = 12\n    num_departments = 4\n\n    # \u8f93\u5165\n    body_input = keras.Input(shape=(None,), name=&#039;body&#039;)\n    title_input = keras.Input(shape=(None,), name=&#039;title&#039;)\n    tag_input = keras.Input(shape=(num_tags,), name=&#039;tag&#039;)\n\n    # \u5d4c\u5165\u5c42\n    body_feat = keras.layers.Embedding(num_words, 64)(body_input)\n    title_feat = keras.layers.Embedding(num_words, 64)(title_input)\n\n    # \u7279\u5f81\u63d0\u53d6\u5c42\n    body_feat = keras.layers.LSTM(32)(body_feat)\n    title_feat = keras.layers.LSTM(128)(title_feat)\n   
 features = keras.layers.concatenate([title_feat, body_feat, tag_input])\n\n    # \u5206\u7c7b\u5c42\n    priority_pred = keras.layers.Dense(\n        1, activation=&#039;sigmoid&#039;, name=&#039;priority&#039;)(features)\n    department_pred = keras.layers.Dense(\n        num_departments, activation=&#039;softmax&#039;, name=&#039;department&#039;)(features)\n\n    # \u6784\u5efa\u6a21\u578b\n    model = keras.Model(inputs=[body_input, title_input, tag_input],\n                        outputs=[priority_pred, department_pred])\n    # model.summary()\n    keras.utils.plot_model(\n        model, &#039;picture\/template_model.png&#039;, show_shapes=True)\n\ndef logisticsMode(shape):\n    model = keras.Sequential()\n    model.add(keras.Input(shape=shape))\n    #model.add(keras.layers.Dense(units=100, activation=&quot;tanh&quot;))\n    model.add(keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True))\n    model.add(keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True))\n    model.add(keras.layers.LSTM(units=200, activation=&#039;tanh&#039;, return_sequences=True))\n    model.add(keras.layers.LSTM(units=200, activation=&#039;tanh&#039;, return_sequences=False))\n    model.add(keras.layers.Dense(units=200, activation=&quot;tanh&quot;))\n    model.add(keras.layers.Dense(units=20, activation=&quot;tanh&quot;))\n    model.add(keras.layers.Dense(units=1))\n    model.compile(optimizer=keras.optimizers.Adam(),\n                  loss=&quot;mse&quot;,\n                  metrics=[keras.metrics.mae])\n    model.summary()\n    return model\n\ndef classifyModel(shape):\n\n    inn = keras.Input(shape=shape)\n    lstm1 = keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True)(inn)\n    lstm2 = keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True)(lstm1)\n    lstm3 = keras.layers.LSTM(units=200, activation=&#039;tanh&#039;, return_sequences=True)(lstm2)\n    lstm4 = 
keras.layers.LSTM(units=50, activation=&#039;tanh&#039;, return_sequences=True)(lstm3)\n    flatten = keras.layers.Flatten()(lstm4)\n    Dense1 = keras.layers.Dense(units=200, activation=&quot;relu&quot;)(flatten)\n    ott = keras.layers.Dense(units=3)(Dense1)\n\n    model = keras.Model(inputs=inn, outputs=ott)\n    model.compile(optimizer=&#039;adam&#039;,\n                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n                  metrics=[&#039;accuracy&#039;])\n    model.summary()\n    # keras.utils.plot_model(model,&quot;picture\/classify_model.png&quot;,show_shapes=True)\n    return model\n\ndef classifyModel_1(shape):\n    inn = keras.Input(shape=shape)\n    lstm1 = keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True)(inn)\n    lstm2 = keras.layers.LSTM(units=500, activation=&#039;tanh&#039;, return_sequences=True)(lstm1)\n    lstm3 = keras.layers.LSTM(units=200, activation=&#039;tanh&#039;, return_sequences=True)(lstm2)\n    lstm4 = keras.layers.LSTM(units=200, activation=&#039;tanh&#039;, return_sequences=True)(lstm3)\n    lstm5 = keras.layers.LSTM(units=100, activation=&#039;tanh&#039;, return_sequences=True)(lstm4)\n    flatten = keras.layers.Flatten()(lstm5)\n    Dense1 = keras.layers.Dense(units=200, activation=&quot;relu&quot;)(flatten)\n    Dense2 = keras.layers.Dense(units=100,activation=&quot;relu&quot;)(Dense1)\n    ott =  keras.layers.Dense(units= 3)(Dense2)\n\n    model = keras.Model(inputs=inn, outputs=ott)\n    model.compile(optimizer=&#039;adam&#039;,\n                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n                  metrics=[&#039;accuracy&#039;])\n    model.summary()\n    # keras.utils.plot_model(model,&quot;picture\/classify_model.png&quot;,show_shapes=True)\n    return model\n\ndef logisticsModel_Ex(daily_shape,area_shape,indu_shape):\n    daily_input = keras.Input(shape= daily_shape)\n    area_input = keras.Input(shape= area_shape)\n    
indu_input = keras.Input(shape= indu_shape)\n\n    lstm1 = keras.layers.LSTM(units=500, activation=&quot;relu&quot;, return_sequences=True)(daily_input)\n    lstm2 = keras.layers.LSTM(units=200, activation=&quot;relu&quot;, return_sequences=True)(lstm1)\n    lstm3 = keras.layers.LSTM(units=200, activation=&quot;relu&quot;, return_sequences=True)(lstm2)\n    lstm4 = keras.layers.LSTM(units=200, activation=&quot;relu&quot;, return_sequences=False)(lstm3)\n\n    # dense_area = keras.layers.Dense(units=200, activation=&quot;relu&quot;)(area_input)\n    # dense_indu = keras.layers.Dense(units=200, activation=&quot;relu&quot;)(indu_input)\n\n    flatten =keras.layers.concatenate([lstm4,area_input,indu_input])\n\n    dense1 = keras.layers.Dense(units=200, activation=&quot;relu&quot;)(flatten)\n    dense2 = keras.layers.Dense(units=100, activation =&quot;relu&quot;)(dense1)\n    dense3 = keras.layers.Dense(units=100, activation= &quot;relu&quot;)(dense2)\n    ott = keras.layers.Dense(units=1)(dense3)\n\n    model =keras.Model(inputs=[daily_input,area_input,indu_input],outputs=ott)\n\n    model.compile(optimizer=keras.optimizers.Adam(),\n                  loss=&quot;mse&quot;,\n                  metrics=[keras.metrics.mae])\n    model.summary()\n    return model\n\nclass trainBase:\n    def __init__(self):\n        self.dataPre:prepare.prepareBase =None\n        self.model:keras.Model=None\n        pass\n\n    @abstractmethod\n    def buildModel(self):\n        pass\n\n    def trainAll(self):\n        for data in self.dataPre.dataGenerator(0):\n            train_x, train_y, stock_list = data\n            print(train_x[0].shape)\n            print(train_x[1].shape)\n            print(train_x[2].shape)\n            print(train_y.shape)\n            history = self.model.fit(train_x, train_y, batch_size=self.batch_size,\n                                     epochs=self.epochs, validation_split=0.05)\n            self.model.save_weights(self.file_path)\n            
stock_sql.updateTrianList(stock_list)\n            del train_x, train_y, stock_list\n            gc.collect()\n            gc.collect()\n\n    def predict(self, stockinfo: prepare.stockInfo) -&gt; (np.array, np.array):\n        data = stock_sql.getRecordData(stockinfo.symbol)\n        if data:\n            return data[0], data[1]\n        data = self.dataPre.getTestData(stockinfo)\n        if data is None:\n            return np.array([]), np.array([])\n        test_x, real_y,scaler_x, scaler_y = data[0], data[1], data[2], data[3]\n        predict_y = self.model.predict(test_x)\n        if scaler_y is not  None:\n            real_y = scaler_y.inverse_transform(real_y)\n            predict_y = scaler_y.inverse_transform(predict_y.astype(&quot;float64&quot;))\n        real_y, predict_y = np.around( real_y, decimals=2), np.around(predict_y, decimals=2)\n\n        stock_sql.SaveRecordData(stockinfo, real_y, predict_y)\n        return real_y, predict_y\n\nclass logisticsTrain(trainBase):\n    def __init__(self):\n        self.file_path = &quot;&quot;.join([&quot;model\/weights\/my_weight&quot;])\n        self.batch_size = 1024\n        self.epochs = 20\n        self.buildModel()\n\n    def buildModel(self):\n        print(&quot;parepareLogistics&quot;)\n        self.dataPre = prepare.logisticsAllScaler()\n\n        self.model = logisticsMode(self.dataPre.getInputShape())\n        try:\n            print(&quot;load_weights from &quot; + self.file_path)\n            self.model.load_weights(self.file_path)\n            print(&quot; success &quot;)\n        except:\n            print(&quot;load_weight failed&quot;)\n            pass\n\n        time.sleep(10)\n\nclass classifyTrain(trainBase):\n    def __init__(self):\n        self.file_path = &quot;&quot;.join([&quot;model\/weights\/classify_weight&quot;])\n        self.batch_size = 1024\n        self.epochs = 5\n        self.buildModel()\n\n    def buildModel(self):\n        self.dataPre = prepare.prepareClassify()\n       
 print(&quot;prepareClassify&quot;)\n        self.model = classifyModel(self.dataPre.getInputShape())\n        try:\n            print(&quot;load_weights from &quot; + self.file_path)\n            self.model.load_weights(self.file_path)\n            print(&quot; success &quot;)\n        except:\n            print(&quot;load_weight failed&quot;)\n            pass\n        time.sleep(10)\n\nclass classifyTrain_1(trainBase):\n    def __init__(self):\n        self.file_path = &quot;&quot;.join([&quot;model\/weights\/classify_weight_1&quot;])\n        self.batch_size = 1024\n        self.epochs = 5\n        self.buildModel()\n\n    def buildModel(self):\n        self.dataPre = prepare.prepareClassify()\n        print(&quot;prepareClassify&quot;)\n        self.model = classifyModel_1(self.dataPre.getInputShape())\n        try:\n            print(&quot;load_weights from &quot; + self.file_path)\n            self.model.load_weights(self.file_path)\n            print(&quot; success &quot;)\n        except:\n            print(&quot;load_weight failed&quot;)\n            pass\n        time.sleep(10)\n\nclass logisticsTrain_Ex(trainBase):\n\n    def __init__(self):\n        self.file_path =&quot;&quot;.join([&quot;model\/weights\/logisticsEx_weight&quot;])\n        self.batch_size = 512\n        self.epochs = 5\n        self.buildModel()\n\n    def buildModel(self):\n        self.dataPre = prepare.prepareLogistics_Ex()\n        print(&quot;prepareClassifyEx&quot;)\n        daily_shape, area_shape, indu_shape = self.dataPre.getInputShape()\n        self.model = logisticsModel_Ex(daily_shape, area_shape, indu_shape)\n        try:\n            print(&quot;load_weights from &quot; + self.file_path)\n            self.model.load_weights(self.file_path)\n            print(&quot; success &quot;)\n        except:\n            print(&quot;load_weight failed&quot;)\n            pass\n        time.sleep(10)\n        pass\n\nif __name__ == &#039;__main__&#039;:\n    # 
stock_sql.initTrainDate()\n\n    train = logisticsTrain_Ex()\n    train.dataPre.only_untrain=False\n    train.trainAll()\n<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>\u7b80\u4ecb\uff1a \u5904\u7406\u8fc7\u7a0b\u5927\u4f53\u5982\u4e0b\uff1a 1\u3001\u4f7f\u7528tushare\u83b7\u53d6stock\u4fe1\u606f 2\u3001\u5bf9\u6570\u636e\u8fdb\u884c\u5904\u7406\uff0c\u505a\u597dtrain_x\u548c &hellip; <\/p>\n<p class=\"link-more\"><a href=\"https:\/\/www.nickchan.cn\/index.php\/2020\/02\/04\/%e8%82%a1%e7%a5%a8%e6%a8%a1%e5%9e%8b%e9%a2%84%e6%b5%8b\/\" class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">\u201c\u80a1\u7968\u6a21\u578b\u9884\u6d4b\u201d<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[138],"tags":[],"_links":{"self":[{"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/posts\/110"}],"collection":[{"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/comments?post=110"}],"version-history":[{"count":2,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/posts\/110\/revisions"}],"predecessor-version":[{"id":117,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/posts\/110\/revisions\/117"}],"wp:attachment":[{"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/media?parent=110"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/categories?post=110"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.nickchan.cn\/index.php\/wp-json\/wp\/v2\/tags?post=110"}],"curie
s":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}