
Commit

show an example for multiple input embeddings, and also allow transformerwrapper not to have logits
lucidrains committed Aug 6, 2024
1 parent 1c45c72 commit cc436b3
Showing 3 changed files with 39 additions and 3 deletions.
setup.py (2 changes: 1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.32.2',
+  version = '1.32.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py (29 changes: 28 additions, 1 deletion)
@@ -126,4 +126,31 @@ def test_attn_softclamp_logits():
 
     x = torch.randint(0, 256, (1, 1024))
 
-    model(x)
+    model(x)
+
+def test_multiple_input_embeds():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        return_only_embed = True,
+        embed_num_tokens = dict(
+            pitch = 32,
+            tone = 16
+        ),
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed_ids = dict(
+        pitch = torch.randint(0, 32, (2, 1024)),
+        tone = torch.randint(0, 16, (2, 1024))
+    )
+
+    embed = model(x, embed_ids = embed_ids)
+
+    assert embed.shape == (2, 1024, 128)
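For reference, the second change in the commit message (letting TransformerWrapper go without a logits head) is only exercised together with the extra embeddings in the test above. The sketch below is not part of this commit; it assumes the same public API as that test and only illustrates that with return_only_embed = True the forward pass returns hidden states of shape (batch, seq_len, dim) rather than logits over num_tokens.

import torch
from x_transformers import TransformerWrapper, Decoder

# with return_only_embed = True no to_logits head is built,
# so the forward pass falls back to returning the final embeddings
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    return_only_embed = True,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (2, 1024))

embed = model(x)

assert embed.shape == (2, 1024, 128)  # (batch, seq_len, dim), not (batch, seq_len, num_tokens)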
x_transformers/x_transformers.py (11 changes: 10 additions, 1 deletion)
@@ -1899,6 +1899,7 @@ def __init__(
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        return_only_embed = False,
         num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
@@ -1948,13 +1949,17 @@ def __init__(
 
         self.init_()
 
+        assert num_output_heads > 0
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
 
         self.has_multiple_heads = False
 
-        if tie_embedding:
+        if return_only_embed:
+            self.to_logits = None
+        elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.has_multiple_heads = True
@@ -2008,7 +2013,9 @@ def forward(
         **kwargs
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
+        return_embeddings = return_embeddings | (not exists(self.to_logits))
 
         # absolute positional embedding
 
@@ -2018,6 +2025,8 @@
 
         # add additional embeddings
 
+        assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
+
         if exists(self.embeds):
             assert len(embed_ids) == len(self.embeds)
 
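The new assertion in forward pairs with the changes above: passing embed_ids to a model that was built without embed_num_tokens (or the reverse) now fails immediately with a clear message instead of erroring later in the embedding lookup. A minimal sketch of that failure mode, assuming the same public API as the test in this commit:

import torch
from x_transformers import TransformerWrapper, Decoder

# built without embed_num_tokens, so no additional embedding tables exist
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (2, 1024))

# passing embed_ids anyway now trips the new assert up front
try:
    model(x, embed_ids = dict(pitch = torch.randint(0, 32, (2, 1024))))
except AssertionError as err:
    print(err)  # '`embed_num_tokens` must be defined on `TransformerWrapper`'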