-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_dan.py
85 lines (75 loc) · 3.42 KB
/
main_dan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
from dan import EncoderDAN
# Default input data file; used when --infile is not given on the command line.
DEFAULT_DATA_INFILE = 'data/medmentions.json'


def build_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser``. ``--fasttext``, ``--outfile`` and
        ``--gpu`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--cca", action='store_true',
                        help='Use to apply the CCA linear constraint.')
    parser.add_argument("--fasttext", type=str,
                        help='Path to the fastText embedding model.',
                        required=True)
    parser.add_argument("--outfile", type=str,
                        help='Output file for the model checkpoints and training stats.',
                        required=True)
    parser.add_argument("--gpu", type=int,
                        help='Index of the GPU used during training.',
                        required=True)
    parser.add_argument("--infile", type=str, default=DEFAULT_DATA_INFILE,
                        help='Path to the input data file passed to the encoder.')
    parser.add_argument("--hidden", type=int, default=38400,
                        help='Size of the hidden layer.')
    parser.add_argument("--layers", type=int, default=1,
                        help='Number of hidden layers.')
    parser.add_argument("--batch", type=int, default=64,
                        help='Batch size. Number of names in a batch.')
    parser.add_argument("--dropout", type=float, default=0.5,
                        help='Dropout rate for training.')
    parser.add_argument("--synonym_dropout", type=float, default=0.5,
                        help='Synonym dropout rate for the conceptual grounding constraint.')
    parser.add_argument("--triplet", type=float, default=0.1,
                        help='Triplet margin for the siamese loss.')
    parser.add_argument("--learning_rate", type=float, default=0.001,
                        help='Learning rate.')
    return parser


def checkpoint_path(outfile, epoch_ref):
    """Return the checkpoint file name for ``epoch_ref``.

    Args:
        outfile: The ``--outfile`` prefix that was handed to training.
        epoch_ref: The value of ``encoder.best_checkpoint`` after training.

    Returns:
        The string ``'<outfile>_<epoch_ref>.cpt'``.

    Raises:
        RuntimeError: If ``epoch_ref`` is falsy, i.e. no best checkpoint is
            available to load.
    """
    # An explicit raise instead of `assert`: asserts are removed under
    # `python -O`, which would let the script go on to load '<outfile>_None.cpt'.
    # The truthiness test mirrors the original assert.
    # NOTE(review): a checkpoint reference of 0 is rejected too, as before --
    # confirm that 0 is never a valid value of `best_checkpoint`.
    if not epoch_ref:
        raise RuntimeError(
            'Training finished without a best checkpoint; nothing to evaluate.')
    return '{}_{}.cpt'.format(outfile, epoch_ref)


def report_results(encoder):
    """Print the synonym retrieval results of ``encoder``.

    For each of the test, zero-shot test and train evaluations, the encoder
    method is called with ``baseline=False`` and then with ``baseline=True``,
    and each returned value is printed.
    """
    print('RESULTS:')
    evaluations = (
        ('TEST:', encoder.synonym_retrieval_test),
        ('ZERO-SHOT TEST:', encoder.synonym_retrieval_zeroshot),
        ('TRAIN:', encoder.synonym_retrieval_train),
    )
    for heading, evaluate in evaluations:
        print(heading)
        print('1) trained DAN:')
        print(evaluate(baseline=False))
        print('2) baseline:')
        print(evaluate(baseline=True))


def main(argv=None):
    """Train an ``EncoderDAN``, reload its best checkpoint and print results.

    Args:
        argv: Command-line arguments without the program name. ``None`` makes
            argparse read ``sys.argv``.

    Raises:
        RuntimeError: If the encoder reports no best checkpoint after training.
    """
    args = build_parser().parse_args(argv)

    print('Initializing encoder...')
    encoder = EncoderDAN(
        data_infile=args.infile,
        fasttext_model_path=args.fasttext,
        triplet_margin=args.triplet,
        hidden_size=args.hidden,
        num_layers=args.layers,
        batch_size=args.batch,
        learning_rate=args.learning_rate,
        dropout_rate=args.dropout,
        proto_dropout_rate=args.synonym_dropout,
        gpu_index=args.gpu,
    )

    if args.cca:
        print('Fitting CCA linear constraint...')
        encoder.fit_cca()

    print('Started training...')
    encoder.train(outfile=args.outfile)
    print('Finished training!')

    # Retrieve the best checkpoint and load it before running the baseline
    # and the trained model.
    epoch_ref = encoder.best_checkpoint
    model_path = checkpoint_path(args.outfile, epoch_ref)
    print('Best model checkpoint: {}'.format(epoch_ref))
    encoder.load_model(model_path)

    report_results(encoder)


if __name__ == '__main__':
    main()