-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathopenke_triple_classification.py
98 lines (83 loc) · 2.67 KB
/
openke_triple_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import openke
from openke.module.model import TransE, TransD, RESCAL, DistMult, ComplEx, SimplE
import joblib
import torch
import numpy as np
from collections import defaultdict
import argparse
import os
import sys
import timeit
import time
from data import (
TASK_REV_MEDIUMHAND,
TASK_LABELS,
)
import metrics
from utils import Task, openke_predict, get_entity_relationship_dicts
# Embedding hyper-parameters for each supported OpenKE architecture, keyed by
# the value accepted for --model. ent_tot / rel_tot are supplied at
# construction time because they depend on the loaded entity/relation maps.
MODEL_CONFIGS = {
    'transe': (TransE, dict(dim=200, p_norm=1, norm_flag=True)),
    'transd': (TransD, dict(dim_e=200, dim_r=200, p_norm=1, norm_flag=True)),
    'rescal': (RESCAL, dict(dim=50)),
    'distmult': (DistMult, dict(dim=200)),
    'complex': (ComplEx, dict(dim=200)),
    'simple': (SimplE, dict(dim=200)),
}

parser = argparse.ArgumentParser()
# `choices` makes argparse reject unknown model names with a clear message;
# previously an unrecognised --model left `model` undefined and the script
# crashed later with a NameError.
parser.add_argument("--model", default='transe', choices=list(MODEL_CONFIGS))
parser.add_argument(
    "--checkpoint", default=None,
    help="optional path to a saved state_dict to load before evaluation; "
         "this script loads no weights otherwise")
args = parser.parse_args()

# Entity and relation dictionaries; their lengths size the embedding tables.
ent_list, rel_list = get_entity_relationship_dicts()

model_cls, model_kwargs = MODEL_CONFIGS[args.model]
model = model_cls(ent_tot=len(ent_list), rel_tot=len(rel_list), **model_kwargs)

# NOTE(review): without --checkpoint the model is evaluated with whatever
# weights its constructor produced, unless utils.openke_predict loads
# weights itself -- confirm.
if args.checkpoint is not None:
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; evaluation below runs on the CPU.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
    model.eval()
model = model.cpu()
# Per-task lookup specification:
#   (suffix appended to the head name, suffix appended to the tail name,
#    final positional argument passed to openke_predict).
# The suffixes select the typed entity keys ('-o', '-a', '-p') in ent_list.
# NOTE(review): the final value presumably picks the task's relation among the
# candidate relations [[0], [1], [2]] -- confirm against utils.openke_predict.
TASK_SPECS = {
    'situated-OP': ('-o', '-p', 0),
    'situated-OA': ('-o', '-a', 1),
    'situated-AP': ('-a', '-p', 2),
}

start_time = timeit.default_timer()
task_names = ['situated-OP', 'situated-OA', 'situated-AP']
for task_name in task_names:
    print('{} task'.format(task_name))
    task_key = TASK_REV_MEDIUMHAND[task_name]
    head_suffix, tail_suffix, rel_arg = TASK_SPECS[task_name]

    predictions, gold_labels, names = [], [], []
    for example in Task(task_key).get_test_examples():
        names.append(example.name)
        head_name, tail_name = example.name.split('/')
        prediction = openke_predict(
            model,
            np.array(ent_list[head_name + head_suffix]),
            np.array(ent_list[tail_name + tail_suffix]),
            np.array([[0], [1], [2]]),
            rel_arg,
        )
        predictions.append(prediction)
        gold_labels.append(int(example.label))

    report_text = metrics.report(
        np.array(predictions),
        np.array(gold_labels),
        names,
        TASK_LABELS[task_key],
    )
    print(report_text)

stop_time = timeit.default_timer()
elapsed = stop_time - start_time
print('triple classification testing time: {}'.format(elapsed))