-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdistmult_eval.py
executable file
·75 lines (61 loc) · 1.83 KB
/
distmult_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import openke
from openke.config import Trainer, Tester
from openke.module.model import DistMult
from openke.module.loss import SoftplusLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader
import joblib
import torch
import numpy as np
from collections import defaultdict
import argparse
import os
import sys
import timeit
from data import (
TASK_REV_MEDIUMHAND,
TASK_LABELS,
)
import metrics
# Make sure the directory that will hold the saved model exists.
# exist_ok=True turns this into a no-op when the directory is already there
# and removes the check-then-create race of a separate os.path.exists() test.
# If 'checkpoint' exists but is not a directory, this raises FileExistsError
# here rather than failing later when the checkpoint is written.
os.makedirs('checkpoint', exist_ok=True)
# Training data loader. All settings are collected in one mapping and then
# unpacked into the constructor so they can be read (or logged) in one place.
# NOTE(review): presumably neg_ent / neg_rel are the numbers of corrupted
# entities / relations sampled per positive triple - confirm against OpenKE.
_train_loader_settings = dict(
    in_path="./data/kge/openke/",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)
train_dataloader = TrainDataLoader(**_train_loader_settings)

# Test data loader: same dataset directory, opened in "link" mode.
test_dataloader = TestDataLoader("./data/kge/openke/", "link")
# Build the DistMult model, sized from the totals reported by the training
# data loader, with 200-dimensional embeddings.
_entity_count = train_dataloader.get_ent_tot()
_relation_count = train_dataloader.get_rel_tot()
distmult = DistMult(ent_tot=_entity_count, rel_tot=_relation_count, dim=200)

# Wrap the model in the negative-sampling strategy with a softplus loss.
# The batch size is taken from the training data loader.
_softplus = SoftplusLoss()
_batch_size = train_dataloader.get_batch_size()
model = NegativeSampling(
    model=distmult,
    loss=_softplus,
    batch_size=_batch_size,
    regul_rate=1.0,
)
# Number of training rounds handed to Trainer as train_times. The same value
# is the divisor for the reported average, so it is defined once here to keep
# the two from drifting apart when the training length is changed.
TRAIN_TIMES = 2000

start_time = timeit.default_timer()
# train the model
trainer = Trainer(
    model=model,
    data_loader=train_dataloader,
    train_times=TRAIN_TIMES,
    alpha=0.5,
    use_gpu=True,
    opt_method="adagrad",
)
trainer.run()
distmult.save_checkpoint('./checkpoint/distmult.ckpt')
# NOTE: the timer is stopped after the checkpoint has been written, so the
# reported average also includes the one-off cost of saving the model.
stop_time = timeit.default_timer()
print('average training time: {}'.format((stop_time - start_time) / TRAIN_TIMES))
# Evaluation phase: reload the saved weights, run link prediction on the test
# loader, and report the wall-clock time of the whole phase (checkpoint
# loading included, since the timer starts before it).
start_time = timeit.default_timer()

_checkpoint_path = './checkpoint/distmult.ckpt'
distmult.load_checkpoint(_checkpoint_path)

tester = Tester(
    model=distmult,
    data_loader=test_dataloader,
    use_gpu=True,
)
tester.run_link_prediction(type_constrain=False)

stop_time = timeit.default_timer()
_elapsed = stop_time - start_time
print('link prediction testing time: {}'.format(_elapsed))