-
Notifications
You must be signed in to change notification settings - Fork 1
/
TrainFullPolicyNetwork.py
85 lines (64 loc) · 2.97 KB
/
TrainFullPolicyNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as data_utils
from ChessConvNet import ChessConvNet
import ChessResNet
from PolicyDataset import PolicyDataset
from FullPolicyDataset import FullPolicyDataset
import h5py
def _ensureParentDirectory(path):
    """Create the directory that will contain ``path`` if it does not exist.

    ``torch.save`` does not create directories, so without this a run that
    writes to e.g. ``"New Networks/x.pt"`` would train until its first
    checkpoint and then crash.
    """
    import os
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _loadPretrainedModel(model, loadDirectory, device):
    """Return the model pickled at ``loadDirectory`` (else ``model``), moved to ``device``.

    Loading is best-effort: an empty/missing path prints
    "Pretrained NN model not found!", a failing load prints the error, and in
    both cases training continues from ``model``.
    """
    import inspect
    import os
    if not loadDirectory or not os.path.isfile(loadDirectory):
        print("Pretrained NN model not found!")
        return model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only box.
    loadKwargs = {'map_location': device}
    # The checkpoints written by trainFullPolicyNetwork are whole pickled
    # nn.Modules (torch.save(model, ...)). PyTorch >= 2.6 defaults torch.load
    # to weights_only=True, which refuses such pickles, so opt out where the
    # flag exists. SECURITY: this unpickles arbitrary objects -- only point
    # loadDirectory at checkpoints you produced or otherwise trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        loadKwargs['weights_only'] = False
    try:
        model = torch.load(loadDirectory, **loadKwargs)
    except Exception as err:  # best-effort resume: keep the fresh model, but say why
        print("Could not load pretrained NN model from {!r}: {}".format(loadDirectory, err))
    return model.to(device)


def _printDiagnostics(outputMoves, labels):
    """Print probability range, predicted vs. actual labels and the hit count for one batch.

    ``outputMoves`` feed ``nn.NLLLoss``, so they are treated as
    log-probabilities and exponentiated before printing MAX/MIN.
    """
    # .cpu() is required before .numpy() when training on a GPU.
    logProbs = outputMoves.detach().cpu()
    values = np.exp(logProbs.numpy())
    print("MAX:", np.amax(values))
    print("MIN:", np.amin(values))
    predicted = logProbs.argmax(dim=1).numpy()
    print(predicted)
    actual = labels.detach().cpu().numpy()
    print(actual)
    print("Correct:", (predicted == actual).sum())


def trainFullPolicyNetwork(EPOCHS=1, BATCH_SIZE=1, LR=0.001,
                           loadDirectory='none.pt',
                           saveDirectory='network1.pt',
                           dataset=None, model=None,
                           logEvery=150, saveEvery=200):
    """Train a policy network with NLL loss and Adam, checkpointing to disk.

    The network is resumed from ``loadDirectory`` when that file exists and
    loads; otherwise ``model`` (by default a fresh float64
    ``ChessResNet.PolicyResNetMain``) is trained from its current weights.
    Training runs for exactly ``EPOCHS`` shuffled passes over the dataset.
    The whole module is pickled to ``saveDirectory`` every ``saveEvery``
    steps and once more at the end.

    Args:
        EPOCHS: number of passes over the dataset.
        BATCH_SIZE: mini-batch size of the (shuffled) DataLoader.
        LR: Adam learning rate.
        loadDirectory: path of a pickled model to resume from; an empty or
            missing path means "start from ``model``".
        saveDirectory: path the pickled model is written to; its parent
            directory is created if needed.
        dataset: dataset yielding ``(input, label)`` pairs; defaults to
            ``FullPolicyDataset()``.
        model: network used when nothing is loaded from ``loadDirectory``;
            defaults to ``ChessResNet.PolicyResNetMain().double()``.
        logEvery: print output/prediction diagnostics every this many steps
            (0 or None disables them).
        saveEvery: write an intermediate checkpoint every this many steps
            (0 or None disables them; the final save always happens).

    Returns:
        The trained model.
    """
    if dataset is None:
        # Author's note: to create a prediction, create a new dataset with the
        # states as inputs and np.zeros() as the outputs.
        dataset = FullPolicyDataset()
    if model is None:
        # This is a residual network.
        model = ChessResNet.PolicyResNetMain().double()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainLoader = torch.utils.data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Load (or keep) the model and move it to the same device as the batches
    # *before* building the optimizer, as PyTorch recommends.
    model = _loadPretrainedModel(model, loadDirectory, device)
    model.train()  # a loaded module keeps whatever mode it was pickled in
    _ensureParentDirectory(saveDirectory)

    criterion = nn.NLLLoss()  # expects log-probabilities from the model
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # , weight_decay=0.00001)
    total_step = len(trainLoader)
    for epoch in range(EPOCHS):
        for i, (images, labels) in enumerate(trainLoader):
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            outputMoves = model(images)
            loss = criterion(outputMoves, labels)
            step = i + 1
            if logEvery and step % logEvery == 0:
                # Reuse this forward pass instead of re-running the model,
                # which would also perturb BatchNorm statistics in train mode.
                _printDiagnostics(outputMoves, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.6f}'
                  .format(epoch + 1, EPOCHS, step, total_step, loss.item()))
            if saveEvery and step % saveEvery == 0:
                torch.save(model, saveDirectory)
                print("Updated!")
    torch.save(model, saveDirectory)
    return model
# Script entry point. Author's notes -- PROS: trains the network without
# holding the whole training set in RAM; CONS: very slow. (Presumably because
# FullPolicyDataset reads samples from disk on demand -- confirm.)
# Guarded so that importing this module does not start a training run.
# loadDirectory="" means "no checkpoint": training starts from a fresh network.
if __name__ == "__main__":
    trainFullPolicyNetwork(loadDirectory="",
                           saveDirectory="New Networks/18011810-FULL-POLICY.pt", EPOCHS=1,
                           BATCH_SIZE=1, LR=0.001)