-
Notifications
You must be signed in to change notification settings - Fork 402
/
Copy pathrun_predict.py
195 lines (171 loc) · 8.25 KB
/
run_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# -*- coding:utf-8 -*-
"""
Author: BigCat
"""
import argparse
import json
import time
import datetime
import numpy as np
import tensorflow as tf
from config import *
from get_data import get_current_number, spider
from loguru import logger
# Command-line interface: --name picks the lottery whose models are used (defaults to "ssq").
parser = argparse.ArgumentParser()
parser.add_argument('--name', default="ssq", type=str, help="选择训练数据: 双色球/大乐透")
# NOTE: parsing happens at module level, so merely importing this module reads sys.argv.
args = parser.parse_args()
# Disable eager execution so the graph/session API used by the functions below is available.
tf.compat.v1.disable_eager_execution()
def _restore_checkpoint(ckpt_prefix, loaded_message):
    """Import a saved meta graph into a fresh graph and restore its variables.

    :param ckpt_prefix: checkpoint path without the ``.meta`` suffix
    :param loaded_message: text logged once the restore has finished
    :return: ``(graph, session)``; the session is bound to the new graph
    """
    graph = tf.compat.v1.Graph()
    with graph.as_default():
        saver = tf.compat.v1.train.import_meta_graph("{}.meta".format(ckpt_prefix))
    session = tf.compat.v1.Session(graph=graph)
    try:
        saver.restore(session, ckpt_prefix)
    except Exception:
        # Do not leak the session when the checkpoint cannot be restored.
        session.close()
        raise
    logger.info(loaded_message)
    return graph, session


def load_model(name):
    """Load the red-ball and blue-ball models plus the prediction key names.

    The loading steps are the same for every lottery, so there is a single
    code path, and all paths are derived from ``name`` rather than from the
    module-level command-line arguments.

    :param name: lottery key used to index ``model_args`` and ``name_path``
    :return: ``(red_graph, red_sess, blue_graph, blue_sess, pred_key_d,
        current_number)``
    :raises Exception: whatever the failing step raised; sessions opened
        before the failure are closed first
    """
    paths = model_args[name]["path"]
    opened_sessions = []
    try:
        red_graph, red_sess = _restore_checkpoint(
            "{}red_ball_model.ckpt".format(paths["red"]), "已加载红球模型!"
        )
        opened_sessions.append(red_sess)
        blue_graph, blue_sess = _restore_checkpoint(
            "{}blue_ball_model.ckpt".format(paths["blue"]), "已加载蓝球模型!"
        )
        opened_sessions.append(blue_sess)
        # Load the names of the key output nodes.
        with open("{}/{}/{}".format(model_path, name, pred_key_name)) as f:
            pred_key_d = json.load(f)
        current_number = get_current_number(name)
    except Exception:
        # A later step failed: release every session opened so far.
        for session in opened_sessions:
            session.close()
        raise
    logger.info("【{}】最近一期:{}".format(name_path[name]["name"], current_number))
    return red_graph, red_sess, blue_graph, blue_sess, pred_key_d, current_number
def get_year():
    """Return the last two digits of the current year as an int.

    e.g. 2020 --> 20, 2021 --> 21

    :return: two-digit year taken from the local clock
    """
    current_year = datetime.datetime.now().year
    return current_year % 100
def try_error(mode, name, predict_features, windows_size):
    """Return the prediction features, re-crawling them if rows are missing.

    When ``mode`` is truthy, or ``predict_features`` already holds
    ``windows_size`` rows, the input is returned unchanged. Otherwise the
    data is crawled again from successively earlier start issues until
    exactly ``windows_size`` rows come back.

    :param mode: truthy to skip the check and return the input as is
    :param name: lottery key passed to ``spider`` / ``get_current_number``
    :param predict_features: sized, column-indexable table of recent draws
    :param windows_size: number of rows the model expects
    :return: features holding ``windows_size`` rows (or the untouched input)
    :raises RuntimeError: if no attempt yields ``windows_size`` rows
    """
    if mode or len(predict_features) == windows_size:
        return predict_features

    logger.warning("期号出现跳期,期号不连续!开始查找最近上一期期号!本期预测时间较久!")
    last_current_year = (get_year() - 1) * 1000
    columns = [x[0] for x in ball_name]
    # The end issue does not change between attempts, so fetch it only once.
    current_number = get_current_number(name)
    # Walk the start issue backwards from offset 160; the walk is bounded so
    # a persistent mismatch cannot spin forever.
    for offset in range(160, 0, -1):
        predict_features = spider(name, last_current_year + offset, current_number, "predict")[columns]
        # Compare with the requested window, not a hard-coded row count.
        if len(predict_features) == windows_size:
            return predict_features
        # Random sub-second pause between crawl attempts.
        time.sleep(np.random.random(1).tolist()[0])
    raise RuntimeError(
        "could not collect {} consecutive draws for {}".format(windows_size, name)
    )
def get_red_ball_predict_result(red_graph, red_sess, pred_key_d, predict_features, sequence_len, windows_size):
    """Run the red-ball model on the given features.

    :param red_graph: graph that holds the red-ball model
    :param red_sess: session bound to ``red_graph``
    :param pred_key_d: mapping from ball name to the output tensor name
    :param predict_features: table with one ``<red name>_<i>`` column per position
    :param sequence_len: number of red-ball positions
    :param windows_size: number of rows fed to the model
    :return: ``(pred, name_list)`` where ``name_list`` pairs the red-ball
        entry of ``ball_name`` with each 1-based position
    """
    red_ball = ball_name[0]
    name_list = []
    for position in range(1, sequence_len + 1):
        name_list.append((red_ball, position))
    columns = ["{}_{}".format(ball[0], position) for ball, position in name_list]
    # Shift the values down by one before feeding them to the model.
    data = predict_features[columns].values.astype(int) - 1
    with red_graph.as_default():
        output_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[red_ball[0]])
        feed = {
            "inputs:0": data.reshape(1, windows_size, sequence_len),
            "sequence_length:0": np.array([sequence_len]),
        }
        pred = red_sess.run(output_tensor, feed_dict=feed)
    return pred, name_list
def get_blue_ball_predict_result(blue_graph, blue_sess, pred_key_d, name, predict_features, sequence_len, windows_size):
    """Run the blue-ball model on the given features.

    :param blue_graph: graph that holds the blue-ball model
    :param blue_sess: session bound to ``blue_graph``
    :param pred_key_d: mapping from ball name to the output tensor name
    :param name: lottery key; ``"ssq"`` takes the single-column path
    :param predict_features: table holding the blue-ball column(s)
    :param sequence_len: number of blue-ball positions (unused for ``"ssq"``)
    :param windows_size: number of rows fed to the model
    :return: ``pred`` alone for ``"ssq"``, otherwise ``(pred, name_list)``
    """
    blue_ball = ball_name[1]

    if name != "ssq":
        # Several blue balls: one column per 1-based position.
        name_list = [(blue_ball, position + 1) for position in range(sequence_len)]
        columns = ["{}_{}".format(ball[0], position) for ball, position in name_list]
        data = predict_features[columns].values.astype(int) - 1
        with blue_graph.as_default():
            output_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[blue_ball[0]])
            feed = {
                "inputs:0": data.reshape(1, windows_size, sequence_len),
                "sequence_length:0": np.array([sequence_len]),
            }
            pred = blue_sess.run(output_tensor, feed_dict=feed)
        return pred, name_list

    # "ssq": a single blue-ball column.
    data = predict_features[[blue_ball[0]]].values.astype(int) - 1
    with blue_graph.as_default():
        output_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[blue_ball[0]])
        feed = {"inputs:0": data.reshape(1, windows_size)}
        pred = blue_sess.run(output_tensor, feed_dict=feed)
    return pred
def get_final_result(red_graph, red_sess, blue_graph, blue_sess, pred_key_d, name, predict_features, mode=0):
    """Combine the red-ball and blue-ball predictions into one dict.

    :param red_graph: graph that holds the red-ball model
    :param red_sess: session bound to ``red_graph``
    :param blue_graph: graph that holds the blue-ball model
    :param blue_sess: session bound to ``blue_graph``
    :param pred_key_d: mapping from ball name to the output tensor name
    :param name: lottery key; ``"ssq"`` has a single blue ball
    :param predict_features: table of recent draws fed to both models
    :param mode: index into each ``ball_name`` entry, choosing the label text
    :return: dict mapping each ball label to its predicted value plus one
    """
    m_args = model_args[name]["model_args"]
    single_blue = name == "ssq"

    red_len = m_args["sequence_len"] if single_blue else m_args["red_sequence_len"]
    red_pred, red_name_list = get_red_ball_predict_result(
        red_graph, red_sess, pred_key_d,
        predict_features, red_len, m_args["windows_size"]
    )

    if single_blue:
        blue_pred = get_blue_ball_predict_result(
            blue_graph, blue_sess, pred_key_d,
            name, predict_features, 0, m_args["windows_size"]
        )
        blue_labels = [ball_name[1][mode]]
        blue_values = blue_pred.tolist()
    else:
        blue_pred, blue_name_list = get_blue_ball_predict_result(
            blue_graph, blue_sess, pred_key_d,
            name, predict_features, m_args["blue_sequence_len"], m_args["windows_size"]
        )
        blue_labels = ["{}_{}".format(ball[mode], position) for ball, position in blue_name_list]
        blue_values = blue_pred[0].tolist()

    red_labels = ["{}_{}".format(ball[mode], position) for ball, position in red_name_list]
    ball_name_list = red_labels + blue_labels
    pred_result_list = red_pred[0].tolist() + blue_values

    result = {}
    for b_name, res in zip(ball_name_list, pred_result_list):
        # Add one to undo the minus-one shift applied to the model inputs.
        result[b_name] = int(res) + 1
    return result
def run(name):
    """Load the models for ``name``, crawl the latest draws and log a prediction.

    Every failure is caught and logged rather than propagated, as before,
    but it is now logged at error level with its traceback. The TensorFlow
    sessions are closed on every exit path.

    :param name: lottery key used to index ``model_args`` and ``name_path``
    """
    red_sess = blue_sess = None
    try:
        red_graph, red_sess, blue_graph, blue_sess, pred_key_d, current_number = load_model(name)
        windows_size = model_args[name]["model_args"]["windows_size"]
        data = spider(name, 1, current_number, "predict")
        logger.info("【{}】预测期号:{}".format(name_path[name]["name"], int(current_number) + 1))
        # mode=1 makes try_error return the sliced rows unchanged.
        predict_features_ = try_error(1, name, data.iloc[:windows_size], windows_size)
        logger.info("预测结果:{}".format(get_final_result(
            red_graph, red_sess, blue_graph, blue_sess, pred_key_d, name, predict_features_))
        )
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("模型加载失败,检查模型是否训练,错误:{}".format(e))
    finally:
        # Release the sessions even when prediction failed part-way through.
        for sess in (red_sess, blue_sess):
            if sess is not None:
                sess.close()
# Script entry point: refuse an empty lottery name, otherwise run the prediction.
if __name__ == '__main__':
    if not args.name:
        raise Exception("玩法名称不能为空!")
    run(args.name)