-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
157 lines (121 loc) · 5.32 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from vit_model import ViT
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from utils import save_params_to_json, read_config
import time
# Fix PyTorch's global RNG seed so repeated runs start from the same random
# state. NOTE: this executes at import time, i.e. merely importing this module
# reseeds torch for the importing process.
torch.manual_seed(0)
def _train_one_epoch(model, loader, optimizer, criterion, device, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to optimise; switched to train mode here.
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        The mean of the per-batch loss values (sum of batch losses divided by
        the number of batches).
    """
    model.train()
    running_loss = 0.0
    # Wrap the loader itself rather than ``enumerate(loader)``: tqdm can then
    # read its length and display a total / ETA instead of a bare counter.
    for inputs, labels in tqdm(loader, desc=f'Epoch {epoch + 1} training', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss_value = criterion(outputs, labels)
        loss_value.backward()
        optimizer.step()
        # ``.item()`` already yields a detached Python float.
        running_loss += loss_value.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device, epoch):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        A ``(mean_batch_loss, accuracy_percent)`` tuple.
    """
    model.eval()
    running_loss = 0.0
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f'Epoch {epoch + 1} testing', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            # The predicted class is the index of the largest output per sample.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / len(loader), 100 * correct / total


def training():
    """Train a ViT classifier on MNIST and persist the weights and metrics.

    Hyperparameters and output paths are read from ``hyper_params.json``.
    After every epoch the model is evaluated on the MNIST test split and the
    training loss, test loss, test accuracy and training-pass duration are
    recorded. When all epochs are done, the model's ``state_dict`` is written
    to the configured ``best_path`` and the hyperparameters plus the collected
    metrics are written to the configured ``log_path``.

    Returns:
        None.
    """
    #########################
    # read hyperparameters from config file
    config = read_config('hyper_params.json')
    patch_size = config['patch_size']
    pos_encoding_learnable = config['pos_encoding_learnable']
    token_dim = config['token_dim']
    n_heads = config['n_heads']
    encoder_blocks = config['encoder_blocks']
    mlp_dim = config['mlp_dim']
    batch_size = config['batch_size']
    n_epochs = config['n_epochs']
    lr = config['lr']
    best_path = config['best_path']
    log_filename = config['log_path']

    # Constants passed to the model for the MNIST data loaded below
    image_size = 28
    channels = 1
    num_classes = 10
    #########################

    # Only convert PIL images to tensors; no normalisation is applied.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load MNIST dataset
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Create data loaders (only the training split is shuffled)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Define the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {torch.cuda.get_device_name(device)}" if device.type == "cuda" else "Device: CPU")

    # Create the model and move it to the device (GPU if available)
    model = ViT(image_size=image_size, patch_size=patch_size,
                num_classes=num_classes, channels=channels,
                pos_encoding_learnable=pos_encoding_learnable,
                token_dim=token_dim, mlp_dim=mlp_dim,
                encoder_blocks=encoder_blocks, n_heads=n_heads).to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    loss = CrossEntropyLoss()

    epoch_times = []
    training_losses = []
    testing_losses = []
    testing_accuracies = []

    for epoch in tqdm(range(n_epochs)):
        # Only the training pass is timed; evaluation is excluded.
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments the way time.time() differences can.
        start_time = time.perf_counter()
        t_loss = _train_one_epoch(model, trainloader, optimizer, loss, device, epoch)
        epoch_times.append(time.perf_counter() - start_time)
        training_losses.append(t_loss)
        print(f'Epoch {epoch + 1} training loss: {t_loss:.4f}')

        t_loss, test_accuracy = _evaluate(model, testloader, loss, device, epoch)
        testing_losses.append(t_loss)
        testing_accuracies.append(test_accuracy)
        print(f'Epoch {epoch + 1} testing accuracy: {test_accuracy:.2f}%')
        print(f'Epoch {epoch + 1} testing loss: {t_loss:.4f}')

    print('Finished Training')

    # Save the model.
    # NOTE(review): despite the name ``best_path``, these are the weights
    # after the final epoch, not the best-scoring epoch.
    print(f'Saving model to {best_path}')
    torch.save(model.state_dict(), best_path)

    # NOTE(review): ``n_epochs`` is not part of the logged config; confirm
    # whether consumers of the log need it before adding the key.
    print(f'Saving session to {log_filename}')
    save_params_to_json(log_filename,
                        config={
                            "patch_size": patch_size,
                            "pos_encoding_learnable": pos_encoding_learnable,
                            "token_dim": token_dim,
                            "n_heads": n_heads,
                            "encoder_blocks": encoder_blocks,
                            "mlp_dim": mlp_dim,
                            "batch_size": batch_size,
                            "lr": lr,
                        },
                        metrics={
                            "training_losses": training_losses,
                            "testing_losses": testing_losses,
                            "times_per_epoch": epoch_times,
                            "testing_accuracies": testing_accuracies,
                        })
# Run the training session only when this file is executed as a script, not
# when it is imported.
# NOTE(review): the DataLoaders use worker processes (num_workers=2); on
# platforms that start workers by re-importing the main module this guard is
# presumably what prevents training from being re-launched — confirm.
if __name__ == "__main__":
    training()