# adapters.py (forked from yurakuratov/t5-experiments)

import math

import torch
from torch import nn


# slightly modified implementation from
# https://github.com/jxhe/unify-parameter-efficient-tuning/blob/3222ce2c0079566a28043e22380eb4ab6ad14389/petl/petl_factory.py#L396
class Adapter_Layer(nn.Module):
    """Bottleneck adapter: LayerNorm -> down-projection -> ReLU -> dropout -> up-projection -> scale (+ residual)."""

    def __init__(self, config=None):
        super().__init__()
        self.n_embd = config.n_embd
        self.down_size = config.adapter_bottleneck_dim
        self.adapter_dropout = config.adapter_dropout
        self.adapter_scale = config.adapter_scale
        # 'in' applies LayerNorm before the down-projection, 'out' applies it after the up-projection
        self.adapter_layernorm_option = getattr(config, 'adapter_layernorm_option', 'in')
        # self.non_linearity = args.non_linearity  # use ReLU by default

        self.adapter_layer_norm_before = None
        if self.adapter_layernorm_option in ("in", "out"):
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        if self.adapter_scale == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(self.adapter_scale)

        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        if self.adapter_dropout > 0:
            self.dropout = nn.Dropout(p=self.adapter_dropout)
        else:
            # no-op when dropout is disabled (forward() always calls self.dropout)
            self.dropout = lambda x: x

        # init params with LoRA-style init: Kaiming uniform for the down-projection,
        # zeros for the up-projection, so the adapter starts as an identity mapping
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
            nn.init.zeros_(self.up_proj.weight)
            nn.init.zeros_(self.down_proj.bias)
            nn.init.zeros_(self.up_proj.bias)

    def forward(self, x, add_residual=True, residual=None):
        # the residual defaults to the adapter input; a different tensor can
        # be passed in when the adapter output should be added elsewhere
        residual = x if residual is None else residual
        if self.adapter_layernorm_option == 'in':
            x = self.adapter_layer_norm_before(x)

        # bottleneck: down-project, non-linearity, dropout, up-project, scale
        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = self.dropout(down)
        up = self.up_proj(down)
        up = up * self.scale

        if self.adapter_layernorm_option == 'out':
            up = self.adapter_layer_norm_before(up)

        output = up + residual if add_residual else up
        return output
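

# --- usage sketch (not part of the original module) ---
# A minimal, hedged example of how the adapter might be instantiated and called.
# The config object here is a stand-in built with types.SimpleNamespace; in the
# actual training code these fields are assumed to come from the model config.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        n_embd=768,                 # hidden size of the backbone model
        adapter_bottleneck_dim=64,  # adapter down-projection size
        adapter_dropout=0.1,
        adapter_scale=1.0,          # or "learnable_scalar"
        adapter_layernorm_option="in",
    )
    adapter = Adapter_Layer(config)

    hidden_states = torch.randn(2, 16, config.n_embd)  # (batch, seq_len, n_embd)
    out = adapter(hidden_states)  # residual connection added by default

    # with the zero-initialized up-projection the adapter starts as an identity mapping
    print(out.shape, torch.allclose(out, hidden_states))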