models.py
import torch
import torch.nn as nn
from transformers import AutoModel

# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

HIDDEN_SIZE = 512
NUM_CLASS = 3
ENTAILMENT_LABEL = 0
NEUTRAL_LABEL = 1
CONTRADICTION_LABEL = 2
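
# A minimal sketch (not from the original repo) of how the label constants
# above pair with the classifier's NUM_CLASS logits: plain cross-entropy over
# integer class ids. The helper name nli_loss is hypothetical.
def nli_loss(logits, labels):
    # logits: (N, NUM_CLASS) scores from the model head;
    # labels: (N,) ids in {ENTAILMENT_LABEL, NEUTRAL_LABEL, CONTRADICTION_LABEL}.
    return nn.functional.cross_entropy(logits, labels)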
class SentBert(nn.Module):
    """Siamese BERT sentence encoder with an SBERT-style classification head."""

    def __init__(self, input_dim, output_dim, tokenizer):
        super().__init__()
        # Pre-trained BERT encoder from Hugging Face (8 layers, hidden size 512).
        self.bert_model = AutoModel.from_pretrained(
            "google/bert_uncased_L-8_H-512_A-8")
        self.bert_model.train()
        # Classification head: two linear layers that map the concatenated
        # sentence embeddings to output_dim class logits.
        self.linear1 = nn.Linear(input_dim, input_dim // 2)
        self.linear2 = nn.Linear(input_dim // 2, output_dim)
        # Keep the tokenizer so encode() can accept raw strings.
        self.tokenizer = tokenizer
    def forward(self, sent1, attn_mask1, sent2, attn_mask2):
        # sent1/sent2: (N, T) token ids; attn_mask1/attn_mask2 match their shapes.
        out1 = self.bert_model(sent1, attention_mask=attn_mask1)
        out2 = self.bert_model(sent2, attention_mask=attn_mask2)
        # Zero out padded positions before pooling.
        hidden_states1 = out1['last_hidden_state']  # (N, T1, H)
        hidden_states1 = hidden_states1 * attn_mask1.unsqueeze(-1)
        hidden_states2 = out2['last_hidden_state']  # (N, T2, H)
        hidden_states2 = hidden_states2 * attn_mask2.unsqueeze(-1)
        # Mean-pool over token positions 1: (skipping [CLS]). Padded positions
        # are zeroed by the mask but still counted in the denominator, so this
        # averages over T - 1 positions rather than over valid tokens only.
        embedding1 = torch.mean(hidden_states1[:, 1:, :], dim=1)  # (N, H)
        embedding2 = torch.mean(hidden_states2[:, 1:, :], dim=1)
        # Concatenate embeddings as (u, v, |u - v|) and classify.
        diff = torch.abs(embedding1 - embedding2)
        merged = torch.cat((embedding1, embedding2, diff), -1)  # (N, 3H)
        merged = self.linear1(merged)
        merged = self.linear2(merged)  # (N, NUM_CLASS)
        return merged, (embedding1, embedding2)
    def encode(self, sents):
        # Put the whole module (including the BERT encoder) in eval mode;
        # callers that resume training should call .train() afterwards.
        self.eval()
        with torch.no_grad():
            encoded = self.tokenizer(
                sents, padding=True, truncation=True, return_tensors="pt")
            input_ids = encoded['input_ids']
            attn_mask = encoded['attention_mask']
            out = self.bert_model(input_ids, attention_mask=attn_mask)
            # Same pooling as in forward (mean over positions 1:), but without
            # zeroing the padded positions first.
            embeddings = torch.mean(
                out['last_hidden_state'][:, 1:, :], dim=1)  # (N, hidden_size)
        return embeddings.detach()
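
# A minimal usage sketch, assuming an SNLI-style premise/hypothesis pair; the
# sample sentences and label are illustrative only. The tokenizer checkpoint
# matches the encoder above, and input_dim is 3 * HIDDEN_SIZE because forward
# concatenates (u, v, |u - v|). Everything stays on the CPU here, since
# encode() does not move its inputs to `device`.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "google/bert_uncased_L-8_H-512_A-8")
    model = SentBert(input_dim=3 * HIDDEN_SIZE, output_dim=NUM_CLASS,
                     tokenizer=tokenizer)

    enc1 = tokenizer(["A man is playing a guitar."],
                     padding=True, truncation=True, return_tensors="pt")
    enc2 = tokenizer(["A person is making music."],
                     padding=True, truncation=True, return_tensors="pt")
    logits, _ = model(enc1["input_ids"], enc1["attention_mask"],
                      enc2["input_ids"], enc2["attention_mask"])
    loss = nli_loss(logits, torch.tensor([ENTAILMENT_LABEL]))
    print(loss.item())

    # Standalone sentence embeddings for similarity search.
    embeddings = model.encode(["A man is playing a guitar."])
    print(embeddings.shape)  # (1, HIDDEN_SIZE)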