-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
278 lines (215 loc) · 8.79 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelArgs:
    """Hyper-parameters for the Llama-style ``Transformer`` defined in this file.

    ``vocab_size`` defaults to the sentinel ``-1`` and must be set (e.g. from
    the tokenizer) before a ``Transformer`` is constructed, which asserts on it.
    """

    dim: int = 4096  # model / embedding width
    n_layers: int = 32  # number of EncoderBlocks
    n_heads: int = 32  # number of heads for the queries
    n_kv_heads: Optional[int] = None  # number of heads for keys/values; None -> n_heads
    vocab_size: int = -1  # will be set when the tokenizer is loaded
    multiple_of: int = 256  # FeedForward hidden width is rounded up to a multiple of this
    ffn_dim_multiplier: Optional[float] = None  # optional scale on the FeedForward hidden width
    norm_eps: float = 1e-5  # epsilon used by every RMSNorm in the model

    # KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048
    # Device for the precomputed RoPE table; None lets torch use its default device.
    device: Optional[str] = None
def precompute_theta_pos_frequencies(
    head_dim: int, seq_len: int, device: str, theta: float = 10000.0
):
    """Build the rotary-position-embedding (RoPE) table of unit complex numbers.

    Entry ``[m, i]`` is ``exp(1j * m * theta ** (-2i / head_dim))`` for every
    position ``m`` in ``[0, seq_len)`` and frequency index ``i`` in
    ``[0, head_dim // 2)``.

    Args:
        head_dim: size of one attention head; must be even.
        seq_len: number of positions to precompute.
        device: device on which the table is created.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    # Features are rotated in pairs, so the head size has to be even.
    assert head_dim % 2 == 0, "Dimension must be div by 2"

    # theta_i = theta ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1
    # shape: (head_dim / 2,)
    even_idx = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (theta ** (even_idx / head_dim)).to(device)

    # angle[m, i] = m * theta_i for every position m
    # shape: (seq_len, head_dim / 2)
    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inv_freq).float()

    # Unit-magnitude complex numbers: exp(1j * angle) = cos(angle) + 1j * sin(angle)
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)
def apply_rotary_embeddings(
    x: torch.Tensor, freqs_complex: torch.Tensor, device: str
):
    """Rotate adjacent feature pairs of ``x`` by the given RoPE factors.

    Args:
        x: tensor of shape ``(B, seq_len, H, head_dim)``.
        freqs_complex: complex tensor of shape ``(seq_len, head_dim // 2)``,
            e.g. a row slice of ``precompute_theta_pos_frequencies``.
        device: device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``, placed on ``device``.
    """
    # View each pair (x[..., 2k], x[..., 2k + 1]) as one complex number:
    # (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2)
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)

    # (seq_len, head_dim / 2) -> (1, seq_len, 1, head_dim / 2) so it
    # broadcasts over the batch and head axes.
    rotation = freqs_complex.unsqueeze(0).unsqueeze(2)
    spun = as_complex * rotation

    # Complex -> real pairs -> original (..., head_dim) layout.
    restored = torch.view_as_real(spun).reshape(*x.shape)
    return restored.type_as(x).to(device)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``(B, seq_len, n_kv_heads, head_dim)`` becomes
    ``(B, seq_len, n_kv_heads * n_rep, head_dim)`` with the copies of one
    head placed next to each other. For ``n_rep == 1`` the input tensor
    itself is returned (no copy).

    Raises:
        ValueError: if ``x`` is not 4-dimensional.
    """
    bsz, slen, kv_heads, hdim = x.shape
    if n_rep == 1:
        return x
    # Add a repeat axis right after the head axis, broadcast it without
    # copying, then fold it into the head axis.
    widened = x.unsqueeze(3).expand(bsz, slen, kv_heads, n_rep, hdim)
    return widened.reshape(bsz, slen, kv_heads * n_rep, hdim)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    ``y = weight * x / sqrt(mean(x ** 2 over the last axis) + eps)``. The
    normalisation runs in float32 and is cast back to the input dtype before
    the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at 1, so a fresh layer is a pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # Reciprocal RMS over the last axis, kept as size-1 for broadcasting.
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor):
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised
class SelfAttention(nn.Module):
    """Multi-head self-attention with RoPE, grouped K/V heads and a KV cache.

    Keys and values of every processed position are written into
    ``cache_k`` / ``cache_v`` so later calls only project the new tokens.
    When ``n_kv_heads < n_heads`` each key/value head is repeated
    ``n_heads // n_kv_heads`` times to line up with the query heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Number of heads for keys and values (defaults to the query head count).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Number of heads for the queries.
        self.n_heads_q = args.n_heads
        # How many times each K/V head is repeated to match the query heads.
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Dimension of each head.
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_kv_heads,
            self.head_dim,
        )
        # Non-persistent buffers (not plain attributes) so that
        # ``module.to(device)`` / ``.half()`` move and cast the caches together
        # with the weights; being non-persistent they stay out of
        # ``state_dict`` and do not affect checkpoint loading.
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(
        self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor
    ):
        """Attend from the new tokens in ``x`` to all cached positions.

        Args:
            x: input of shape ``(B, seq_len, dim)``.
            start_pos: absolute position of ``x[:, 0]``; the new keys/values
                are stored at ``start_pos : start_pos + seq_len`` in the cache.
            freqs_complex: RoPE factors for those positions,
                shape ``(seq_len, head_dim // 2)``.

        Returns:
            Tensor of shape ``(B, seq_len, dim)``.
        """
        batch_size, seq_len, _ = x.shape

        # (B, seq_len, H * head_dim) -> (B, seq_len, H, head_dim)
        xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Rotary embeddings are applied to queries and keys only.
        xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # Store the new entries, then read back every position up to end_pos.
        end_pos = start_pos + seq_len
        self.cache_k[:batch_size, start_pos:end_pos] = xk
        self.cache_v[:batch_size, start_pos:end_pos] = xv
        keys = self.cache_k[:batch_size, :end_pos]
        values = self.cache_v[:batch_size, :end_pos]

        # Repeat K/V heads to match the number of query heads.
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (B, T, H, head_dim) -> (B, H, T, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, H, seq_len, end_pos)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if seq_len > 1:
            # Causal mask: new token i sits at absolute position start_pos + i
            # and must not attend to keys after it. With a single new token
            # every cached key is in the past, so no mask is needed.
            q_pos = torch.arange(start_pos, end_pos, device=scores.device)[:, None]
            k_pos = torch.arange(end_pos, device=scores.device)[None, :]
            scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H, seq_len, head_dim) -> (B, seq_len, H * head_dim)
        output = torch.matmul(scores, values)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)
class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width starts at ``int(2 * 4 * dim / 3)``, is optionally
    scaled by ``args.ffn_dim_multiplier``, and is finally rounded up to a
    multiple of ``args.multiple_of``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = int(2 * (4 * args.dim) / 3)
        if args.ffn_dim_multiplier is not None:
            width = int(args.ffn_dim_multiplier * width)
        # Round up to the nearest multiple of ``multiple_of``.
        step = args.multiple_of
        width = step * ((width + step - 1) // step)

        self.w1 = nn.Linear(args.dim, width, bias=False)  # gate branch (goes through SiLU)
        self.w2 = nn.Linear(width, args.dim, bias=False)  # projection back to dim
        self.w3 = nn.Linear(args.dim, width, bias=False)  # linear branch

    def forward(self, x: torch.Tensor):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.w2(gated)
class EncoderBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then feed-forward, each residual.

    ``out = h + FFN(norm(h))`` with ``h = x + Attn(norm(x))``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)
        # Normalise before self-attention.
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # Normalise before the feed-forward network.
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor
    ):
        """Run one layer on ``x``.

        Args:
            x: activations, shape ``(B, seq_len, dim)``.
            start_pos: absolute position of ``x[:, 0]``, forwarded to attention.
            freqs_complex: RoPE factors for the positions in ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Invoke sub-modules through ``__call__`` (not ``.forward``) so that
        # any registered forward hooks on them actually run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        return h + self.feed_forward(self.ffn_norm(h))
class Transformer(nn.Module):
    """Llama-style decoder: token embedding, ``n_layers`` EncoderBlocks,
    a final RMSNorm and a linear LM head.

    ``forward`` processes exactly one new token per sequence; earlier
    positions are read from the per-layer KV caches.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        assert args.vocab_size != -1, "Vocab size is not set"
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
        self.layers = nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)
        # RoPE table covering 2 * max_seq_len positions. It is a plain
        # attribute (not a buffer), so module dtype casts such as
        # ``model.half()`` never touch this complex tensor -- but neither does
        # ``model.to(device)``, which is why ``forward`` relocates the slice.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Return float32 logits for one new token per sequence.

        Args:
            tokens: integer token ids of shape ``(B, 1)``.
            start_pos: absolute position of that token in each sequence.

        Returns:
            Logits of shape ``(B, 1, vocab_size)`` in float32.

        Raises:
            AssertionError: if more than one token per sequence is given.
        """
        _, seq_len = tokens.shape
        assert seq_len == 1, "Only one token at a time can be processed"
        h = self.tok_embeddings(tokens)  # (B, seq_len) -> (B, seq_len, dim)
        # RoPE factors for the positions being processed, moved to where the
        # activations live (no-op when the devices already match).
        freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len].to(h.device)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        return self.output(h).float()
# Llama2 embedding -> 4096