Skip to content

Commit

Permalink
add an improvised per row data dependent alibi, using attention-esque…
Browse files Browse the repository at this point in the history
… forget queries and keys
  • Loading branch information
lucidrains committed Oct 29, 2024
1 parent e2bab6d commit dfd5c8b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.41.1',
version = '1.41.2',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
53 changes: 49 additions & 4 deletions x_transformers/x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,46 @@ def forward(self, x):
forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
return forget_gates

class PerRowDataDependentAlibi(Module):
""" same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queris and keys with a small head dimension """

def __init__(
self,
dim,
heads,
dim_head = 8
):
super().__init__()
self.scale = dim_head ** -0.5

linear = nn.Linear(dim, heads * dim_head * 2, bias = False)

self.to_forget_gates = nn.Sequential(
linear,
Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
)

def forward(self, x):
q, k = self.to_forget_gates(x)
forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

forget_gates = F.logsigmoid(forget_gates)

# mask out upper triangle + diagonal

n = x.shape[-2]
causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()

forget_gates = forget_gates.masked_fill(causal_mask, 0.)

# reverse cumsum

forget_gates = forget_gates.flip(dims = (-1,))
forget_gates = forget_gates.cumsum(dim = -1)
forget_gates = forget_gates.flip(dims = (-1,))

return forget_gates

class RotaryEmbedding(Module):
def __init__(
self,
Expand Down Expand Up @@ -968,6 +1008,8 @@ def __init__(
add_zero_kv = False, # same as add_zero_attn in pytorch
rotary_embed_values = False,
data_dependent_alibi = False,
data_dependent_alibi_per_row = False,
data_dependent_alibi_per_row_dim_head = 8,
use_cope = False,
cope_max_pos = 16,
cope_soft_onehot_pos = False,
Expand Down Expand Up @@ -1079,10 +1121,13 @@ def __init__(
if data_dependent_alibi:
assert causal, 'data dependent alibi only works for autoregressive for now until further research'

self.data_dependent_alibi = DataDependentAlibi(
dim,
heads = heads
)
dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
dda_kwargs = dict(dim = dim, heads = heads)

if data_dependent_alibi_per_row:
dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)

self.data_dependent_alibi = dda_klass(**dda_kwargs)

# attend class - includes core attention algorithm + talking heads

Expand Down

0 comments on commit dfd5c8b

Please sign in to comment.