Support freezing key-value caches.
UFO-101 committed Sep 18, 2023
1 parent 3b3bfe5 commit 94b512b
Showing 3 changed files with 113 additions and 6 deletions.
95 changes: 95 additions & 0 deletions tests/unit/test_kv_cache.py
@@ -80,3 +80,98 @@ def test_multiple_new_tokens():
assert t.allclose(
no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=1e-3
)


def test_freeze_cache():
past_left_attention_mask = utils.get_attention_mask(
model.tokenizer,
pre_prompt_tokens,
model.cfg.default_prepend_bos,
)

post_prompt_1 = " I'm headed to the church to play bingo."
new_tokens_1 = model.to_tokens(post_prompt_1, prepend_bos=False)
full_prompt_tokens_1 = t.cat([pre_prompt_tokens, new_tokens_1], dim=-1)
past_kv_cache_1 = HookedTransformerKeyValueCache.init_cache(
model.cfg, model.cfg.device, pre_prompt_tokens.shape[0]
)

post_prompt_2 = " shine your light on me, Miss Liberty"
new_tokens_2 = model.to_tokens(post_prompt_2, prepend_bos=False)
past_kv_cache_2 = HookedTransformerKeyValueCache.init_cache(
model.cfg, model.cfg.device, pre_prompt_tokens.shape[0]
)

model(
pre_prompt_tokens,
padding_side=padding_side,
past_kv_cache=past_kv_cache_1,
past_left_attention_mask=None,
)
past_kv_cache_1.freeze()
with_cache_logits_1 = model(
new_tokens_1,
padding_side=padding_side,
past_kv_cache=past_kv_cache_1,
past_left_attention_mask=past_left_attention_mask,
)

model(
pre_prompt_tokens,
padding_side=padding_side,
past_kv_cache=past_kv_cache_2,
past_left_attention_mask=None,
)
past_kv_cache_2.freeze()
model(
new_tokens_2,
padding_side=padding_side,
past_kv_cache=past_kv_cache_2,
past_left_attention_mask=past_left_attention_mask,
)

# Caches frozen at the same point should be identical
assert len(past_kv_cache_1.entries) == len(past_kv_cache_2.entries)
for entry_1, entry_2 in zip(past_kv_cache_1.entries, past_kv_cache_2.entries):
assert entry_1.past_keys.shape == entry_2.past_keys.shape
assert entry_1.past_values.shape == entry_2.past_values.shape
assert t.allclose(entry_1.past_keys, entry_2.past_keys, atol=1e-3)
assert t.allclose(entry_1.past_values, entry_2.past_values, atol=1e-3)

# Rerunning the same prompt with a different cache that was frozen at the same
# point should give the same results
with_cache_2_logits_1 = model(
new_tokens_1,
padding_side=padding_side,
past_kv_cache=past_kv_cache_2,
past_left_attention_mask=past_left_attention_mask,
)
assert t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3)

# Test unfreeze
past_kv_cache_2.unfreeze()
with_cache_2_logits_1 = model(
new_tokens_1,
padding_side=padding_side,
past_kv_cache=past_kv_cache_2,
past_left_attention_mask=past_left_attention_mask,
)
for entry_1, entry_2 in zip(past_kv_cache_1.entries, past_kv_cache_2.entries):
assert entry_1.past_keys.shape[1] < entry_2.past_keys.shape[1]
assert entry_1.past_values.shape[1] < entry_2.past_values.shape[1]

    # The unfrozen rerun above should still give the same results, because the
    # cache contents seen during that forward pass were identical to those of
    # the frozen run; unfreezing only allows the cache to be extended afterwards
    assert t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3)
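
    # Now that past_kv_cache_2 also contains the keys and values for
    # new_tokens_1, rerunning the same prompt (with an attention mask covering
    # the longer cache) should give different results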
past_left_attention_mask = utils.get_attention_mask(
model.tokenizer,
full_prompt_tokens_1,
model.cfg.default_prepend_bos,
)
with_cache_2_logits_1 = model(
new_tokens_1,
padding_side=padding_side,
past_kv_cache=past_kv_cache_2,
past_left_attention_mask=past_left_attention_mask,
)
assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3)
8 changes: 4 additions & 4 deletions transformer_lens/HookedTransformer.py
@@ -427,10 +427,10 @@ def forward(
first transformer block, etc. Supports negative indexing. Useful for analysis of intermediate layers, eg finding
neuron activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values will be stored for every
- attention head. If there are keys and values already in the cache, these will be prepended to the keys and values
- for the new input, so that the new tokens can pay attention to previous tokens. This is useful for generating text,
- because we don't need to repeat computation for tokens that have already been through the model. Defaults to None
- (don't use caching).
+ attention head (unless the cache is frozen). If there are keys and values already in the cache, these will be
+ prepended to the keys and values for the new input, so that the new tokens can pay attention to previous tokens.
+ This is useful for generating text, because we don't need to repeat computation for tokens that have already been
+ through the model. Defaults to None (don't use caching).
Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style language models -
if you want a custom loss function, the recommended behaviour is returning the logits and then applying your
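
For reference, the caching pattern this docstring describes looks roughly like the following. This is a minimal sketch, not part of the diff; the model name and prompts are illustrative, and the forward() keyword arguments mirror the tests above at this commit (later versions may differ).

import torch as t
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")

# Run the prefix once, storing its keys and values in the cache.
prefix_tokens = model.to_tokens("The war of the worlds")
cache = HookedTransformerKeyValueCache.init_cache(
    model.cfg, model.cfg.device, prefix_tokens.shape[0]
)
model(prefix_tokens, past_kv_cache=cache, past_left_attention_mask=None)

# Subsequent calls only feed the new tokens; they attend to the cached prefix.
prefix_mask = utils.get_attention_mask(
    model.tokenizer, prefix_tokens, model.cfg.default_prepend_bos
)
new_tokens = model.to_tokens(" began quietly", prepend_bos=False)
with_cache_logits = model(
    new_tokens, past_kv_cache=cache, past_left_attention_mask=prefix_mask
)

# Should match running the full concatenated prompt in one pass (as the
# existing tests check), while skipping recomputation for the prefix.
full_logits = model(t.cat([prefix_tokens, new_tokens], dim=-1))
assert t.allclose(
    full_logits[:, -new_tokens.shape[-1] :], with_cache_logits, atol=1e-3
)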
16 changes: 14 additions & 2 deletions transformer_lens/past_key_value_caching.py
@@ -12,6 +12,7 @@
class HookedTransformerKeyValueCacheEntry:
past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
frozen: bool = False

@classmethod
def init_cache_entry(
@@ -40,8 +41,9 @@ def append(
updated_values: Float[
torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head"
] = torch.cat([self.past_values, new_values], dim=1)
-        self.past_keys = updated_keys
-        self.past_values = updated_values
+        if not self.frozen:
+            self.past_keys = updated_keys
+            self.past_values = updated_values
return updated_keys, updated_values


@@ -51,6 +53,8 @@ class HookedTransformerKeyValueCache:
A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!
This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.
The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
"""

entries: List[HookedTransformerKeyValueCacheEntry]
@@ -73,5 +77,13 @@ def init_cache(
]
)

def freeze(self):
for entry in self.entries:
entry.frozen = True

def unfreeze(self):
for entry in self.entries:
entry.frozen = False

def __getitem__(self, idx):
return self.entries[idx]
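
The freeze/unfreeze workflow the class docstring mentions (running many inputs against the same cached prefix) looks roughly like this. A minimal sketch, not part of the diff; the model name and prompts are illustrative, and the forward() keyword arguments mirror the tests above at this commit.

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")

prefix_tokens = model.to_tokens("An apple a day")
prefix_mask = utils.get_attention_mask(
    model.tokenizer, prefix_tokens, model.cfg.default_prepend_bos
)
cache = HookedTransformerKeyValueCache.init_cache(
    model.cfg, model.cfg.device, prefix_tokens.shape[0]
)

# Fill the cache with the shared prefix once, then freeze it so later forward
# passes read it without extending it.
model(prefix_tokens, past_kv_cache=cache, past_left_attention_mask=None)
cache.freeze()

for continuation in [" keeps the doctor away.", " is one way to eat fruit."]:
    new_tokens = model.to_tokens(continuation, prepend_bos=False)
    # Every continuation sees exactly the same prefix keys and values, because
    # append() returns the concatenated tensors but does not store them while
    # the cache is frozen.
    logits = model(
        new_tokens, past_kv_cache=cache, past_left_attention_mask=prefix_mask
    )

# cache.unfreeze() lets the next forward pass append to the cache again.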
