transformer.py
from transformers import AutoConfig
from transformers.models.bert.modeling_bert import BertEncoder, BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nn_utils import timestep_embedding


class TransformerModel(nn.Module):
    """
    Transformer model with an LM head tied to the embedding matrix that maps
    discrete tokens to continuous space (and back to vocabulary logits).
    """
    def __init__(self,
                 input_dims,
                 output_dims,
                 hidden_t_dim,
                 dropout=0,
                 config=None,
                 config_name="bert-base-uncased",
                 vocab_size=None,
                 init_pretrained="no",
                 logits_mode=1) -> None:
        super().__init__()

        if config is None:
            config = AutoConfig.from_pretrained(config_name)
            config.hidden_dropout_prob = dropout

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_t_dim = hidden_t_dim
        self.dropout = dropout
        self.logits_mode = logits_mode
        self.hidden_size = config.hidden_size

        # word_embedding -> maps discrete tokens to continuous embeddings
        # lm_head        -> maps output embeddings back to vocabulary (tokens)
        self.word_embedding = nn.Embedding(vocab_size, self.input_dims)
        self.lm_head = nn.Linear(self.input_dims, vocab_size)
        # weight sharing: lm_head and word embeddings use the same weight matrix
        with torch.no_grad():
            self.lm_head.weight = self.word_embedding.weight

        time_embed_dim = hidden_t_dim * 4
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_t_dim, time_embed_dim),
            nn.GELU(),
            nn.Linear(time_embed_dim, self.hidden_size),
        )

        if self.input_dims != config.hidden_size:
            self.input_up_proj = nn.Sequential(
                nn.Linear(input_dims, config.hidden_size),
                nn.Tanh(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )

        if init_pretrained == 'bert':
            print("Using pretrained BERT weights.....")
            print(config)
            lm = BertModel.from_pretrained(config_name, config=config)
            # set word embeddings to the pretrained BERT embeddings and re-tie the lm_head
            self.word_embedding = lm.embeddings.word_embeddings
            with torch.no_grad():
                self.lm_head.weight = self.word_embedding.weight
            self.encoder = lm.encoder
            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
            self.position_embeddings = lm.embeddings.position_embeddings
            self.LayerNorm = lm.embeddings.LayerNorm
            del lm.embeddings
            del lm.pooler
        elif init_pretrained == 'no':
            self.encoder = BertEncoder(config=config)
            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
            self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            assert False, "Pretrained type not supported."

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if self.output_dims != config.hidden_size:
            self.output_down_proj = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.Tanh(),
                nn.Linear(config.hidden_size, self.output_dims),
            )
    def get_embeddings_from_input_ids(self, input_ids):
        return self.word_embedding(input_ids)

    def get_logits(self, hidden_repr):
        """
        Given output hidden representations, obtain logits over the vocabulary.
        """
        if self.logits_mode == 1:
            return self.lm_head(hidden_repr)
        elif self.logits_mode == 2:
            raise NotImplementedError
        else:
            raise NotImplementedError
    def forward(self, x, timesteps):
        """
        Apply the transformer model to an input batch.
        -> x: a [B x seq_len x ...] Tensor of inputs.
        -> timesteps: a 1-D batch of timesteps.
        :return: a [B x seq_len x ...] Tensor of outputs.
        """
        # obtain the time embedding: [B x hidden_t_dim] sinusoidal features projected to [B x hidden_size]
        emb_t = self.time_embed(timestep_embedding(timesteps, self.hidden_t_dim))

        # project inputs up to the encoder width if needed
        if self.input_dims != self.hidden_size:
            emb_x = self.input_up_proj(x)
        else:
            emb_x = x

        seq_len = x.size(1)
        position_ids = self.position_ids[:, :seq_len]

        # position_ids -> [1 x seq_len] -> position embeddings -> [1 x seq_len x hidden_size] (broadcast over B)
        # emb_x        -> [B x seq_len x hidden_size]
        # emb_t        -> [B x hidden_size] -> unsqueeze and expand to [B x seq_len x hidden_size]
        emb_inputs = self.position_embeddings(position_ids) \
            + emb_x \
            + emb_t.unsqueeze(1).expand(-1, seq_len, -1)
        # apply the embedding LayerNorm and dropout defined in __init__
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))

        output_embeddings = self.encoder(emb_inputs).last_hidden_state
        if self.output_dims != self.hidden_size:
            output_embeddings = self.output_down_proj(output_embeddings)
        output_embeddings = output_embeddings.type(x.dtype)
        return output_embeddings
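

# --- Minimal usage sketch (illustrative, not part of the module itself) ---
# A hedged example of how this model might be driven end-to-end. The
# hyperparameters below (input_dims/output_dims/hidden_t_dim=128,
# vocab_size=30522) are assumptions chosen for illustration, and the example
# relies on `nn_utils.timestep_embedding` being importable as above and on
# the "bert-base-uncased" config being downloadable.
if __name__ == "__main__":
    model = TransformerModel(
        input_dims=128,          # embedding width fed into the encoder (up-projected to hidden_size)
        output_dims=128,         # width of the returned embeddings (down-projected from hidden_size)
        hidden_t_dim=128,        # width of the sinusoidal timestep features
        vocab_size=30522,        # assumed bert-base-uncased vocabulary size
        init_pretrained="no",    # random init; "bert" would load pretrained weights instead
    )
    input_ids = torch.randint(0, 30522, (2, 16))         # [B=2, seq_len=16] token ids
    x = model.get_embeddings_from_input_ids(input_ids)   # [2, 16, 128] continuous embeddings
    t = torch.randint(0, 1000, (2,))                     # one diffusion timestep per sample
    out = model(x, t)                                    # [2, 16, 128] output embeddings
    logits = model.get_logits(out)                       # [2, 16, 30522] logits via the tied lm_head
    print(out.shape, logits.shape)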