-
Notifications
You must be signed in to change notification settings - Fork 1
/
DoubleHeadDataset.py
53 lines (37 loc) · 1.5 KB
/
DoubleHeadDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch.utils.data as data_utils
import torch.nn as nn
import torch
import numpy as np
import h5py
import ActionToArray
class DoubleHeadTrainingDataset(torch.utils.data.Dataset):
    """Dataset of (board, policy target, value target) triples with weighted policy targets.

    On access, each stored feature is decoded with
    ``ActionToArray.binaryArrayToBoard``. The policy target is a dense vector
    that is zero everywhere except at the sample's stored policy index, which
    holds that sample's policy magnitude.
    """

    def __init__(self, inputs, policyOut, policyMag, valueOut, policySize=2308):
        """Store the per-sample inputs and targets.

        Args:
            inputs: indexable collection of encoded samples; each item is passed
                to ``ActionToArray.binaryArrayToBoard``. Its length defines the
                dataset length.
            policyOut: per-sample policy index (torch tensor on CPU, NumPy array
                or sequence); each entry is converted with ``int()``.
            policyMag: per-sample value written at the policy index.
            valueOut: per-sample value-head target.
            policySize: length of the dense policy vector. Defaults to 2308,
                the size this module always used.
        """
        self.features = inputs
        # Kept as passed for callers that read it; original note: could be cast
        # with .type(torch.LongTensor) for nll loss.
        self.targets = policyOut
        self.targets2 = valueOut
        # np.asarray gives the same array as .numpy() for CPU torch tensors,
        # and additionally accepts NumPy arrays and plain sequences.
        self.numpy = np.asarray(policyOut)
        self.targetMag = policyMag
        self.policySize = policySize

    def __getitem__(self, index):
        """Return ``(board, policy, value)`` for sample *index*.

        ``board`` is ``ActionToArray.binaryArrayToBoard(self.features[index])``.
        ``policy`` is a float64 torch tensor of length ``policySize`` that is
        zero except at the stored policy index, where it holds
        ``self.targetMag[index]``. ``value`` is ``self.targets2[index]`` as a
        NumPy array with a leading axis of length 1 added.

        Raises:
            IndexError: if the stored policy index is negative or not smaller
                than ``policySize``.
        """
        # BinaryConverted Method!!
        board = ActionToArray.binaryArrayToBoard(self.features[index])

        move = int(self.numpy[index])
        # NumPy would silently wrap a negative index to the end of the vector
        # and corrupt the target, so check both bounds explicitly.
        if not 0 <= move < self.policySize:
            raise IndexError(
                f"policy index {move} out of range for policy size {self.policySize}"
            )
        policy = np.zeros(self.policySize)
        policy[move] = self.targetMag[index]

        return board, torch.from_numpy(policy), np.expand_dims(self.targets2[index], axis=0)

    def __len__(self):
        """Return the number of samples, i.e. ``len(inputs)``."""
        return len(self.features)
class DoubleHeadDataset(torch.utils.data.Dataset):
    """Dataset of (features, one-hot policy target, value target) triples.

    Features are returned exactly as stored. The policy target is a dense
    vector that is 1 at the sample's stored policy index and 0 elsewhere.
    """

    def __init__(self, inputs, policyOut, valueOut, policySize=2308):
        """Store the per-sample inputs and targets.

        Args:
            inputs: indexable collection of samples returned as-is. Its length
                defines the dataset length.
            policyOut: per-sample policy index (torch tensor on CPU, NumPy array
                or sequence); each entry is converted with ``int()``.
            valueOut: per-sample value-head target.
            policySize: length of the one-hot policy vector. Defaults to 2308,
                the size this module always used.
        """
        self.features = inputs
        # Kept as passed for callers that read it; original note: could be cast
        # with .type(torch.LongTensor) for nll loss.
        self.targets = policyOut
        self.targets2 = valueOut
        # np.asarray gives the same array as .numpy() for CPU torch tensors,
        # and additionally accepts NumPy arrays and plain sequences.
        self.numpy = np.asarray(policyOut)
        self.policySize = policySize

    def __getitem__(self, index):
        """Return ``(features, policy, value)`` for sample *index*.

        ``features`` is ``self.features[index]`` unchanged. ``policy`` is a
        float64 torch tensor of length ``policySize`` holding 1 at the stored
        policy index and 0 elsewhere. ``value`` is ``self.targets2[index]`` as
        a NumPy array with a leading axis of length 1 added.

        Raises:
            IndexError: if the stored policy index is negative or not smaller
                than ``policySize``.
        """
        move = int(self.numpy[index])
        # NumPy would silently wrap a negative index to the end of the vector
        # and corrupt the target, so check both bounds explicitly.
        if not 0 <= move < self.policySize:
            raise IndexError(
                f"policy index {move} out of range for policy size {self.policySize}"
            )
        policy = np.zeros(self.policySize)
        policy[move] = 1

        return self.features[index], torch.from_numpy(policy), np.expand_dims(self.targets2[index], axis=0)

    def __len__(self):
        """Return the number of samples, i.e. ``len(inputs)``."""
        return len(self.features)