Commit
allow for rotary embeddings to be constructed outside the attention layers and passed in
lucidrains committed Nov 22, 2023
1 parent 9de5a5c commit 6220d86
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.25.2',
+  version = '1.25.3',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
7 changes: 3 additions & 4 deletions x_transformers/x_transformers.py
@@ -1191,7 +1191,8 @@ def forward(
         seq_start_pos: Optional[Tensor] = None,
         cache: Optional[LayerIntermediates] = None,
         cache_age = 1,
-        return_hiddens = False
+        return_hiddens = False,
+        rotary_pos_emb = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'

@@ -1219,9 +1220,7 @@ def forward(

         # rotary positions

-        rotary_pos_emb = None
-
-        if exists(self.rotary_pos_emb):
+        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)

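A minimal usage sketch of the new keyword argument (not part of the commit itself; the Decoder settings, tensor shapes, and variable names below are illustrative assumptions). It relies only on what the diff shows: forward() now accepts rotary_pos_emb, and the internally constructed embedding is skipped when one is supplied, because of the not exists(rotary_pos_emb) guard.

    import torch
    from x_transformers import Decoder

    # hypothetical decoder stack; rotary_pos_emb = True makes it build an internal
    # RotaryEmbedding module, reachable as attn_layers.rotary_pos_emb
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8, rotary_pos_emb = True)

    x = torch.randn(2, 1024, 512)  # (batch, seq, dim) token embeddings

    # before this commit, the rotary embedding was always recomputed inside forward()
    out = attn_layers(x)

    # after this commit, an embedding constructed outside can be passed in, e.g. to
    # share it across several stacks or cache it between calls; the call below
    # mirrors the internal self.rotary_pos_emb(max_rotary_emb_length) in the diff
    rotary_pos_emb = attn_layers.rotary_pos_emb(x.shape[1])
    out = attn_layers(x, rotary_pos_emb = rotary_pos_emb)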
