diff --git a/setup.py b/setup.py
index 7e0780e..f68137d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.41.1',
+  version = '1.41.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index cc80f6d..1c26a0b 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -519,6 +519,46 @@ def forward(self, x):
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
+class PerRowDataDependentAlibi(Module):
+    """ same as the data dependent alibi from the forgetting transformer, but the forgetting gates are derived from queries and keys with a small head dimension """
+
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+        )
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates)
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumsum
+
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))
+
+        return forget_gates
+
 class RotaryEmbedding(Module):
     def __init__(
         self,
@@ -968,6 +1008,8 @@ def __init__(
         add_zero_kv = False,  # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1079,10 +1121,13 @@ def __init__(
 
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
-            self.data_dependent_alibi = DataDependentAlibi(
-                dim,
-                heads = heads
-            )
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
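For context, a minimal usage sketch of the new option (not part of the diff). It assumes the usual x-transformers convention that `attn_`-prefixed kwargs on `Decoder` are forwarded to the `Attention` layers; the vocabulary size, sequence length, and model dimensions below are arbitrary placeholders.

```python
# usage sketch - assumes attn_* kwargs on Decoder are routed to Attention
# (standard x-transformers kwarg routing); sizes below are arbitrary

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True,                # existing flag: enables the data dependent alibi bias
        attn_data_dependent_alibi_per_row = True,        # new flag from this diff: selects PerRowDataDependentAlibi
        attn_data_dependent_alibi_per_row_dim_head = 8   # head dimension of the small q/k projection used for the forget gates
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```

Compared to `DataDependentAlibi`, the per-row variant computes a forget gate for each (query, key) pair from a small query/key projection, so the gate varies per row (query position) before the same log-sigmoid and reverse-cumsum accumulation is applied.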