-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathtest.py
77 lines (69 loc) · 3.07 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
@author: Viet Nguyen <[email protected]>
"""
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from src.utils import get_evaluation
from src.dataset import MyDataset
import argparse
import shutil
import csv
import numpy as np
def get_args():
    """Parse and return the command-line options for model evaluation.

    Returns:
        argparse.Namespace with batch_size, data_path, pre_trained_model,
        word2vec_path, and output attributes.
    """
    arg_parser = argparse.ArgumentParser(
        """Implementation of the model described in the paper: Hierarchical Attention Networks for Document Classification""")
    arg_parser.add_argument("--batch_size", type=int, default=128)
    arg_parser.add_argument("--data_path", type=str, default="data/test.csv")
    arg_parser.add_argument("--pre_trained_model", type=str, default="trained_models/whole_model_han")
    arg_parser.add_argument("--word2vec_path", type=str, default="data/glove.6B.50d.txt")
    arg_parser.add_argument("--output", type=str, default="predictions")
    return arg_parser.parse_args()
def test(opt):
    """Evaluate a pre-trained HAN model on a test set and write predictions.

    Loads the saved model, runs it over the dataset at ``opt.data_path``,
    writes per-sample predictions to ``<opt.output>/predictions.csv``, and
    prints loss/accuracy/confusion-matrix metrics.

    Args:
        opt: argparse.Namespace with batch_size, data_path,
            pre_trained_model, word2vec_path, and output attributes.

    Side effects:
        Deletes and recreates the ``opt.output`` directory.
    """
    test_params = {"batch_size": opt.batch_size,
                   "shuffle": False,
                   "drop_last": False}
    # Start from a clean output directory on every run.
    if os.path.isdir(opt.output):
        shutil.rmtree(opt.output)
    os.makedirs(opt.output)
    # NOTE(review): torch.load unpickles arbitrary code — only load model
    # files from a trusted source.
    if torch.cuda.is_available():
        model = torch.load(opt.pre_trained_model)
    else:
        # Remap GPU-saved tensors onto the CPU.
        model = torch.load(opt.pre_trained_model, map_location=lambda storage, loc: storage)
    test_set = MyDataset(opt.data_path, opt.word2vec_path, model.max_sent_length, model.max_word_length)
    test_generator = DataLoader(test_set, **test_params)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    te_label_ls = []
    te_pred_ls = []
    for te_feature, te_label in test_generator:
        num_sample = len(te_label)
        if torch.cuda.is_available():
            te_feature = te_feature.cuda()
            te_label = te_label.cuda()
        with torch.no_grad():
            # Hidden state must be re-sized for the (possibly smaller) last batch.
            model._init_hidden_state(num_sample)
            te_predictions = model(te_feature)
        # Fix: pass dim explicitly — F.softmax without dim is deprecated and
        # relies on an implicit-dimension heuristic. Class scores are on dim 1.
        te_predictions = F.softmax(te_predictions, dim=1)
        te_label_ls.extend(te_label.clone().cpu())
        te_pred_ls.append(te_predictions.clone().cpu())
    te_pred = torch.cat(te_pred_ls, 0).numpy()
    te_label = np.array(te_label_ls)
    fieldnames = ['True label', 'Predicted label', 'Content']
    # Fix: newline='' is required by the csv module; without it the writer
    # emits blank rows on Windows.
    with open(os.path.join(opt.output, "predictions.csv"), 'w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        for i, j, k in zip(te_label, te_pred, test_set.texts):
            # Labels are stored 0-based; report them 1-based as in the data file.
            writer.writerow(
                {'True label': i + 1, 'Predicted label': np.argmax(j) + 1, 'Content': k})
    test_metrics = get_evaluation(te_label, te_pred,
                                  list_metrics=["accuracy", "loss", "confusion_matrix"])
    print("Prediction:\nLoss: {} Accuracy: {} \nConfusion matrix: \n{}".format(test_metrics["loss"],
                                                                               test_metrics["accuracy"],
                                                                               test_metrics["confusion_matrix"]))
if __name__ == "__main__":
    # Script entry point: parse CLI options and run evaluation.
    test(get_args())