
Commit

turn off kv caching for absolute positional embedding, when exceeding the context window during decoding, addressing #219

lucidrains committed Dec 30, 2023
1 parent 324cfac commit a549c55
Showing 4 changed files with 21 additions and 4 deletions.
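
Why this change: with cached kv decoding and an absolute positional embedding, generating past max_seq_len slides the input window, so the positions the cached keys and values were computed with no longer match the tokens' new indices. A rough plain-Python illustration of the mismatch (the token ids and max_seq_len below are made up for the example, not taken from the library):

    # after generating 6 tokens with max_seq_len = 4, only the last window is fed to the model
    max_seq_len = 4
    out = [10, 11, 12, 13, 14, 15]                                   # token ids generated so far
    window = out[-max_seq_len:]                                      # [12, 13, 14, 15]

    cached_positions = {tok: pos for pos, tok in enumerate(out)}     # token 12 was cached at position 2
    window_positions = {tok: pos for pos, tok in enumerate(window)}  # token 12 now sits at position 0

    # with absolute positional embeddings, the cached key/value for token 12 encodes position 2,
    # but the shifted window expects position 0, so the cache is stale; rotary embeddings depend
    # only on relative offsets between query and key, which the shift leaves unchanged
    assert cached_positions[12] != window_positions[12]
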
14 changes: 12 additions & 2 deletions examples/enwik8_simple/train.py
@@ -39,7 +39,12 @@ def decode_tokens(tokens):
 model = TransformerWrapper(
     num_tokens = 256,
     max_seq_len = SEQ_LEN,
-    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
 )
 
 model = AutoregressiveWrapper(model)
@@ -101,6 +106,11 @@ def __len__(self):
 prime = decode_tokens(inp)
 print(f'%s \n\n %s', (prime, '*' * 100))
 
-sample = model.generate(inp, GENERATE_LENGTH)
+sample = model.generate(
+    prompts = inp,
+    seq_len = GENERATE_LENGTH,
+    cache_kv = True
+)
+
 output_str = decode_tokens(sample)
 print(output_str)
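
With rotary embeddings in the Decoder, the example can keep cache_kv = True even when the generated sequence grows past the context window, which is exactly the case the new assert further below guards against. A minimal standalone sketch along the same lines (the model sizes and lengths are illustrative, smaller than the example script's):

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

    model = AutoregressiveWrapper(TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(
            dim = 128,
            depth = 2,
            heads = 4,
            rotary_pos_emb = True   # relative positions, so the kv cache survives the window shift
        )
    ))

    prompt = torch.randint(0, 256, (1, 1))

    # 128 generated tokens exceed max_seq_len = 64, which is fine with rotary embeddings
    sample = model.generate(
        prompts = prompt,
        seq_len = 128,
        cache_kv = True
    )
    print(sample.shape)   # generated tokens only, without the prompt
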
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.26.4',
+  version = '1.26.6',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
4 changes: 4 additions & 0 deletions x_transformers/autoregressive_wrapper.py
@@ -188,6 +188,10 @@ def generate(
         for _ in range(seq_len):
 
             if restrict_to_max_seq_len:
+                max_len_exceeded = out.shape[-1] > max_seq_len
+
+                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
+
                 x = out[:, -max_seq_len:]
 
                 if exists(cache):
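
The assert turns a silent failure into an immediate error: with the default absolute positional embedding, requesting cached decoding past max_seq_len now raises with a hint to switch to rotary embeddings. A hedged sketch of the behaviour (model sizes are arbitrary):

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

    abs_pos_model = AutoregressiveWrapper(TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(dim = 128, depth = 2, heads = 4)   # no rotary, so absolute positions are used
    ))

    prompt = torch.randint(0, 256, (1, 1))

    try:
        # decoding 128 tokens exceeds max_seq_len = 64 while cache_kv = True
        abs_pos_model.generate(prompts = prompt, seq_len = 128, cache_kv = True)
    except AssertionError as err:
        print(err)   # message suggests switching to rotary embeddings
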
5 changes: 4 additions & 1 deletion x_transformers/x_transformers.py
@@ -1495,7 +1495,9 @@ def __init__(
         self.l2norm_embed = l2norm_embed
         self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
 
-        if max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb)
+
+        if no_abs_pos_emb:
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
@@ -1536,6 +1538,7 @@ def __init__(
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
         if self.l2norm_embed:
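
The new attribute can also be checked up front, before deciding whether to pass cache_kv = True for generations longer than the context window. A small hedged sketch (constructor values are arbitrary):

    from x_transformers import TransformerWrapper, Decoder

    absolute = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
    )

    rotary = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(dim = 128, depth = 2, heads = 4, rotary_pos_emb = True)
    )

    print(absolute.can_cache_kv_outside_max_seq_len)   # False: absolute positional embedding in use
    print(rotary.can_cache_kv_outside_max_seq_len)     # True: rotary embeddings carry the positions
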
