Commit

allow for cross attending to embeddings of a different dimension than the model dimensions of the decoder
lucidrains committed Nov 22, 2023
1 parent e585d33 commit 9de5a5c
Showing 2 changed files with 8 additions and 4 deletions.
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.25.1',
+  version = '1.25.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (7 additions, 3 deletions)
@@ -641,6 +641,7 @@ def __init__(
         self,
         dim,
         dim_head = DEFAULT_DIM_HEAD,
+        dim_context = None,
         heads = 8,
         causal = False,
         flash = False,
@@ -669,6 +670,8 @@ def __init__(
         onnxable = False
     ):
         super().__init__()
+        dim_kv = default(dim_context, dim)
+
         self.scale = dim_head ** -0.5
 
         self.heads = heads
@@ -691,11 +694,11 @@
         out_dim = value_dim_head * heads
 
         self.to_q = nn.Linear(dim, q_dim, bias = False)
-        self.to_k = nn.Linear(dim, k_dim, bias = False)
+        self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
 
         # shared key / values, for further memory savings during inference
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
+        self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
 
         # relations projection from tp-attention
         self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
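
Aside (illustrative, not part of the diff): because to_k and to_v now project from dim_kv rather than dim, the context fed to a cross-attention layer may have a different feature width than the decoder stream. A standalone shape sketch with made-up sizes, mirroring the projections set up above:

import torch
from torch import nn

# illustrative sizes only; dim_kv falls back to dim when no dim_context is given
dim, dim_context, heads, dim_head = 512, 768, 8, 64
dim_kv = dim_context if dim_context is not None else dim

to_q = nn.Linear(dim, dim_head * heads, bias = False)     # queries come from the decoder stream
to_k = nn.Linear(dim_kv, dim_head * heads, bias = False)  # keys come from the external context
to_v = nn.Linear(dim_kv, dim_head * heads, bias = False)

x = torch.randn(2, 1024, dim)           # decoder hidden states
context = torch.randn(2, 77, dim_kv)    # context embeddings of a different width

q, k, v = to_q(x), to_k(context), to_v(context)   # each projects to heads * dim_head = 512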
@@ -1000,6 +1003,7 @@ def __init__(
 
         ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
+        cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
 
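Aside (illustrative, not part of the diff): the new cross_attn_ prefix follows the same convention as the existing ff_ and attn_ prefixes, so matching kwargs are split off, stripped of the prefix, and handed only to the cross-attention ('c') layers below. A minimal sketch of that routing with made-up values, assuming groupby_prefix_and_trim returns (matching kwargs with the prefix stripped, remaining kwargs) as its usage here suggests:

from x_transformers.x_transformers import groupby_prefix_and_trim

kwargs = {'attn_dim_head': 64, 'cross_attn_dim_context': 768}
attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)               # attn_kwargs == {'dim_head': 64}
cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)   # cross_attn_kwargs == {'dim_context': 768}

# cross-attention layers then receive both groups, with cross_attn_ values taking precedence
layer_kwargs = {**attn_kwargs, **cross_attn_kwargs}   # {'dim_head': 64, 'dim_context': 768}
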
@@ -1144,7 +1148,7 @@ def __init__(
             if layer_type == 'a':
                 layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':
-                layer = Attention(dim, heads = heads, **attn_kwargs)
+                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
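
Aside (illustrative, not from the commit): with these changes, a decoder built through the usual TransformerWrapper / Decoder entry points should be able to cross attend to context embeddings whose width differs from its own dim, by passing the new option with the cross_attn_ prefix. The sizes below are made up, and the sketch assumes the wrappers keep forwarding prefixed kwargs down to AttentionLayers (x-transformers >= 1.25.2):

import torch
from x_transformers import TransformerWrapper, Decoder

# decoder runs at dim = 512 but cross attends to 768-dim context embeddings
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True,
        cross_attn_dim_context = 768  # stripped to dim_context and routed to the 'c' layers
    )
)

x = torch.randint(0, 20000, (1, 1024))   # decoder token ids
context = torch.randn(1, 77, 768)        # e.g. embeddings from a separate encoder

logits = model(x, context = context)     # (1, 1024, 20000)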
