-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_adv_wgan_gp.py
76 lines (71 loc) · 2.23 KB
/
run_adv_wgan_gp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from mmkgc.config import Tester, WCGTrainerGP
from mmkgc.module.model import AdvRelRotatE
from mmkgc.module.loss import SigmoidLoss
from mmkgc.module.strategy import NegativeSamplingGP
from mmkgc.data import TrainDataLoader, TestDataLoader
from mmkgc.adv.modules import CombinedGenerator
from args import get_args
def main():
    """Train an AdvRelRotatE scorer adversarially, then run link prediction.

    Pipeline, executed in this order:
      1. parse command-line options with ``get_args`` and seed torch's RNGs;
      2. build the training loader and the link-prediction test loader for
         ``./benchmarks/<dataset>/`` and load the visual / textual entity
         embeddings from ``./embeddings/<dataset>-{visual,textual}.pth``;
      3. wrap the AdvRelRotatE scorer in ``NegativeSamplingGP`` with a
         ``SigmoidLoss`` and train it through ``WCGTrainerGP`` together with a
         ``CombinedGenerator`` (the "GP" suffix presumably denotes a gradient
         penalty, per the script name -- TODO confirm in mmkgc);
      4. write a checkpoint to ``args.save``, read it back, and evaluate link
         prediction without type constraints.
    """
    args = get_args()
    print(args)

    # Seed before any loader or model is constructed so torch-side randomness
    # (e.g. parameter initialisation) is governed by --seed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Both loaders read from the same benchmark directory; the trailing slash
    # is kept because the loaders receive this value as a directory prefix.
    dataset_dir = "./benchmarks/" + args.dataset + '/'

    train_loader = TrainDataLoader(
        in_path=dataset_dir,
        batch_size=args.batch_size,
        threads=8,
        sampling_mode="normal",
        bern_flag=1,
        filter_flag=1,
        neg_ent=args.neg_num,
        neg_rel=0,
    )
    test_loader = TestDataLoader(dataset_dir, "link")

    # Pre-computed multimodal entity features handed to the scorer below.
    visual_feats = torch.load('./embeddings/' + args.dataset + '-visual.pth')
    textual_feats = torch.load('./embeddings/' + args.dataset + '-textual.pth')

    scorer = AdvRelRotatE(
        ent_tot=train_loader.get_ent_tot(),
        rel_tot=train_loader.get_rel_tot(),
        dim=args.dim,
        margin=args.margin,
        epsilon=2.0,
        img_emb=visual_feats,
        text_emb=textual_feats,
    )
    print(scorer)

    # Training objective: negative sampling + sigmoid loss with a small
    # regularisation weight.
    objective = NegativeSamplingGP(
        model=scorer,
        loss=SigmoidLoss(adv_temperature=args.adv_temp),
        batch_size=train_loader.get_batch_size(),
        regul_rate=1e-5,
    )

    # Generator widths are twice the embedding dimension -- presumably
    # matching RotatE's real+imaginary entity representation; TODO confirm.
    doubled_dim = 2 * args.dim
    fake_generator = CombinedGenerator(
        noise_dim=64,
        structure_dim=doubled_dim,
        img_dim=doubled_dim,
    )

    trainer = WCGTrainerGP(
        model=objective,
        data_loader=train_loader,
        train_times=args.epoch,
        alpha=args.learning_rate,
        use_gpu=True,
        opt_method='Adam',
        generator=fake_generator,
        lrg=args.lrg,
        mu=args.mu,
    )
    trainer.run()

    # Persist the trained scorer, then read the same file back so the
    # evaluation below uses the weights as stored at args.save.
    scorer.save_checkpoint(args.save)
    scorer.load_checkpoint(args.save)

    evaluator = Tester(model=scorer, data_loader=test_loader, use_gpu=True)
    evaluator.run_link_prediction(type_constrain=False)


if __name__ == "__main__":
    main()