-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutils.py
112 lines (93 loc) · 5.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.optim as optim
import torch.nn as nn
import models
from torch.utils.data import DataLoader, TensorDataset
# Device used for models and batches. Prefer the GPU, but fall back to the CPU
# so the module still works on machines without CUDA (a hard-coded 'cuda'
# makes every later `.to(device)` call fail there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_data(args):
    """Load a dataset split, min-max scale it per channel, and wrap it in DataLoaders.

    Tensors are read from ``./data/<args.dataset>/train/`` and, depending on
    ``args.test_val_train``, from ``test/`` or ``val/`` (or the training tensors
    are reused for evaluation). Tensors are indexed as ``[:, 0, channel, ...]``.
    Per-channel minima/maxima of the *training targets* are used to scale the
    inputs and targets of both the training and evaluation sets.

    Args:
        args: namespace providing ``dataset``, ``test_val_train`` (``'test'``,
            ``'val'`` or ``'train'``), ``dim_channels`` and ``batch_size``.

    Returns:
        ``[train_loader, val_loader, mean, std, max_val, train_shape_in,
        train_shape_out, val_shape_in, val_shape_out]``. ``mean``/``std`` are
        computed on the unscaled training targets. ``max_val`` is a
        ``(dim_channels, 1)`` tensor of per-channel training-target maxima. The
        shapes are those of the tensors as loaded.

    Raises:
        ValueError: if ``args.test_val_train`` is not a recognised split name.

    Side effects:
        Sets the module globals ``train_shape_in``, ``train_shape_out``,
        ``val_shape_in``, ``val_shape_out``, ``max_val`` and ``min_val``.
        ``process_for_eval`` reads ``min_val`` to undo the scaling.
    """
    global train_shape_in, train_shape_out, val_shape_in, val_shape_out
    global max_val, min_val

    root = './data/' + args.dataset
    input_train = torch.load(root + '/train/input_train.pt')
    target_train = torch.load(root + '/train/target_train.pt')
    if args.test_val_train in ('test', 'val'):
        split = args.test_val_train
        input_val = torch.load(f'{root}/{split}/input_{split}.pt')
        target_val = torch.load(f'{root}/{split}/target_{split}.pt')
    elif args.test_val_train == 'train':
        # Clone rather than alias: the scaling loop below works in place, and
        # shared tensors would otherwise be scaled twice.
        input_val = input_train.clone()
        target_val = target_train.clone()
    else:
        raise ValueError(
            "args.test_val_train must be 'test', 'val' or 'train', "
            f"got {args.test_val_train!r}"
        )

    # Record dimensions before any transformation.
    train_shape_in = input_train.shape
    train_shape_out = target_train.shape
    val_shape_in = input_val.shape
    val_shape_out = target_val.shape

    # Statistics of the raw (unscaled) training targets.
    mean = target_train.mean()
    std = target_train.std()

    max_val = torch.zeros((args.dim_channels, 1))
    min_val = torch.zeros((args.dim_channels, 1))
    for i in range(args.dim_channels):
        max_val[i] = target_train[:, 0, i, ...].max()
        min_val[i] = target_train[:, 0, i, ...].min()

    # Min-max scale every tensor in place using the training-target range.
    # NOTE(review): a constant channel (max == min) divides by zero here.
    for i in range(args.dim_channels):
        scale = max_val[i] - min_val[i]
        for tensor in (input_train, target_train, input_val, target_val):
            tensor[:, 0, i, ...] = (tensor[:, 0, i, ...] - min_val[i]) / scale

    train_data = TensorDataset(input_train, target_train)
    val_data = TensorDataset(input_val, target_val)
    train = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val = DataLoader(val_data, batch_size=args.batch_size, shuffle=False)
    return [train, val, mean, std, max_val, train_shape_in, train_shape_out, val_shape_in, val_shape_out]
def load_model(args, discriminator=False):
    """Build the network selected by ``args.model`` and move it to ``device``.

    Args:
        args: namespace with ``model`` (``'convgru'``, ``'flowconvgru'``,
            ``'gan'`` or ``'cnn'``) plus the hyper-parameters each constructor
            below reads (``number_channels``, ``number_residual_blocks``,
            ``upsampling_factor``, ``constraints``, ``dim_channels``,
            ``constraints_window_size``).
        discriminator: when true, build ``models.Discriminator()`` and ignore
            ``args.model``.

    Returns:
        The constructed model, moved with ``.to(device)``.

    Raises:
        ValueError: if ``args.model`` names no known architecture (previously
            this surfaced as an ``UnboundLocalError``).
    """
    if discriminator:
        model = models.Discriminator()
    elif args.model == 'convgru':
        model = models.ConvGRUGeneratorDet(
            number_channels=args.number_channels,
            number_residual_blocks=args.number_residual_blocks,
            upsampling_factor=args.upsampling_factor,
            time_steps=3,
            constraints=args.constraints,
            cwindow_size=args.constraints_window_size,
        )
    elif args.model == 'flowconvgru':
        model = models.TimeEndToEndModel(
            number_channels=args.number_channels,
            number_residual_blocks=args.number_residual_blocks,
            upsampling_factor=args.upsampling_factor,
            time_steps=3,
            constraints=args.constraints,
        )
    elif args.model in ('gan', 'cnn'):
        # Both use models.ResNet; only the `noise` flag differs (True for 'gan').
        model = models.ResNet(
            number_channels=args.number_channels,
            number_residual_blocks=args.number_residual_blocks,
            upsampling_factor=args.upsampling_factor,
            noise=(args.model == 'gan'),
            constraints=args.constraints,
            dim=args.dim_channels,
        )
    else:
        raise ValueError(
            "args.model must be one of 'convgru', 'flowconvgru', 'gan', 'cnn', "
            f"got {args.model!r}"
        )
    return model.to(device)
def get_optimizer(args, model):
    """Create an Adam optimiser over all parameters of ``model``.

    The learning rate and weight decay are taken from ``args.lr`` and
    ``args.weight_decay``.
    """
    hyperparams = {'lr': args.lr, 'weight_decay': args.weight_decay}
    return optim.Adam(model.parameters(), **hyperparams)
def get_criterion(args, discriminator=False):
    """Return a fresh loss module.

    Binary cross-entropy when ``discriminator`` is truthy, mean-squared error
    otherwise. ``args`` is accepted for a uniform signature but is not read.
    """
    loss_class = nn.BCELoss if discriminator else nn.MSELoss
    return loss_class()
def mass_loss(output, in_val, args):
    """MSE between ``in_val`` and the average-pooled first channel of ``output``.

    ``output[:, 0, 0, :, :]`` is average-pooled with kernel size (and hence
    stride) ``args.upsampling_factor`` before the comparison. Presumably
    ``in_val`` is the matching low-resolution field; confirm against callers.
    """
    factor = args.upsampling_factor
    first_channel = output[:, 0, 0, :, :]
    pooled = torch.nn.functional.avg_pool2d(first_channel, factor)
    return torch.nn.functional.mse_loss(pooled, in_val)
def get_loss(output, true_value, in_val, args):
    """Training loss: plain MSE, or a blend with ``mass_loss``.

    When ``args.loss == 'mass_constraints'`` the result is
    ``alpha * mass_loss(output, in_val[:, 0, 0, ...]) + (1 - alpha) * MSE``
    with ``alpha = args.alpha``. Otherwise it is just the MSE between
    ``output`` and ``true_value`` (and ``in_val`` is not used).
    """
    if args.loss != 'mass_constraints':
        return torch.nn.functional.mse_loss(output, true_value)
    weight = args.alpha
    constraint_term = mass_loss(output, in_val[:, 0, 0, ...], args)
    reconstruction_term = torch.nn.functional.mse_loss(output, true_value)
    return weight * constraint_term + (1 - weight) * reconstruction_term
def process_for_training(inputs, targets):
    """Return ``(inputs, targets)`` moved onto the module-level ``device``."""
    return inputs.to(device), targets.to(device)
def process_for_eval(outputs, targets, mean, std, max_val, args, min_val=None):
    """Undo ``load_data``'s per-channel min-max scaling, in place.

    Channel ``c`` (indexed as ``[:, 0, c, ...]``) is mapped back via
    ``x * (max_val[c] - min_val[c]) + min_val[c]``. For ``args.model == 'gan'``
    only channel 0 is restored; otherwise all ``args.dim_channels`` channels are.

    Args:
        outputs: model predictions; modified in place.
        targets: ground truth; modified in place.
        mean: unused; kept for call-site compatibility.
        std: unused; kept for call-site compatibility.
        max_val: per-channel maxima, as returned by ``load_data``.
        args: namespace with ``model`` and ``dim_channels``.
        min_val: per-channel minima. Defaults to the module global set by
            ``load_data``.

    Returns:
        ``(outputs, targets)``: the same tensor objects, rescaled.

    Raises:
        NameError: if ``min_val`` is neither passed nor set by ``load_data``.
    """
    if min_val is None:
        try:
            min_val = globals()['min_val']
        except KeyError:
            raise NameError(
                "min_val is not set; call load_data() first or pass min_val"
            ) from None
    # The GAN path restores only the first channel. It previously wrote to
    # outputs[:, :, 0, 0, ...], which also indexed the H axis and could not
    # accept the (N, H, W) right-hand side.
    n_channels = 1 if args.model == 'gan' else args.dim_channels
    for i in range(n_channels):
        for tensor in (outputs, targets):
            # Use each tensor's own device instead of a hard-coded one.
            low = min_val[i].to(tensor.device)
            span = max_val[i].to(tensor.device) - low
            tensor[:, 0, i, ...] = tensor[:, 0, i, ...] * span + low
    return outputs, targets
def is_gan(args):
    """Report whether ``args.model`` selects the GAN variant (``'gan'``)."""
    return args.model == 'gan'