-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_models.py
76 lines (63 loc) · 1.99 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import numpy as np
from codebase.models.BBBTimeSeriesPredModel import BBBTimeSeriesPredModel
from codebase.train import train
import data.data_utils as data_ut
# Smoke-test script: builds a BBBTimeSeriesPredModel and trains it on dummy
# data produced by data_ut.dummy_data_creator with data_ut.sinusoidal_kernel.
# NOTE(review): all of this runs at module import time and there is no
# `if __name__ == "__main__":` guard, so pytest collecting this file will run
# the full training loop (up to `iter_max` iterations).

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Sequence / batch shape ------------------------------------------------
batch_size = 50
n_batches = 100  # used by dummy data
n_input_steps = 50
n_pred_steps = 20
full_seq_len = n_input_steps + n_pred_steps
input_feat_dim = 2  # 4 for stocks data
pred_feat_dim = 2  # 1 for stocks data

# ---- Network ----------------------------------------------------------------
hidden_feat_dim = 80
# Presumably the parameters of a two-component scale-mixture prior (mixing
# weight and component standard deviations) -- TODO confirm in the model.
pi = 0.3
std1 = np.exp(1)
std2 = np.exp(-6)
# NOTE(review): `gpu` stays False even when `device` resolves to CUDA; confirm
# whether the model still consults this flag now that `device` is passed too.
gpu = False
BBB = True
training = True
sharpen = True

# ---- Optimisation -----------------------------------------------------------
iter_max = 50000


def _build_model():
    """Construct the model from the module-level settings and move it to `device`."""
    settings = {
        'pi': pi,
        'std1': std1,
        'std2': std2,
        'gpu': gpu,
        'BBB': BBB,
        'training': training,
        'sharpen': sharpen,
        'input_feat_dim': input_feat_dim,
        'pred_feat_dim': pred_feat_dim,
        'hidden_feat_dim': hidden_feat_dim,
        'n_input_steps': n_input_steps,
        'n_pred_steps': n_pred_steps,
        'name': 'test_BBBRNN',
        'device': device,
    }
    return BBBTimeSeriesPredModel(**settings).to(device)


model = _build_model()

# Disabled manual forward/loss sanity check, kept for reference:
# np.random.seed(1234)
# inputs = torch.Tensor(np.random.rand(n_input_steps, batch_size, input_feat_dim))
# hidden = model.init_hidden(batch_size)
# targets = torch.Tensor(np.random.rand(n_pred_steps, batch_size, pred_feat_dim))
# outputs, hidden = model.forward(inputs, hidden, targets)
# assert outputs.shape == torch.Size([n_pred_steps, batch_size, 2 * pred_feat_dim])
# print(outputs.shape)
# loss = model.get_loss(outputs, targets)
# print(loss)

# NOTE(review): pred_feat_dim is not forwarded to the data creator; confirm the
# generated targets actually have pred_feat_dim features.
dummy_training_set = data_ut.dummy_data_creator(
    batch_size=batch_size,
    n_batches=n_batches,
    input_feat_dim=input_feat_dim,
    n_input_steps=n_input_steps,
    n_pred_steps=n_pred_steps,
    kernel=data_ut.sinusoidal_kernel,
    device=device,
)

# iter_save / iter_plot = inf presumably means "never checkpoint or plot"
# during this run -- TODO confirm against codebase.train.train.
train(
    model, dummy_training_set, batch_size, n_batches, device,
    kernel=data_ut.sinusoidal_kernel,
    lr=1e-3,
    clip_grad=5,
    iter_max=iter_max,
    iter_save=np.inf,
    iter_plot=np.inf,
    reinitialize=False,
)