From 94b512b9f28a9158d68d162be5746f2c6db499d4 Mon Sep 17 00:00:00 2001 From: UFO-101 Date: Mon, 18 Sep 2023 03:14:17 +0100 Subject: [PATCH] Support freezing key-value caches. --- tests/unit/test_kv_cache.py | 95 ++++++++++++++++++++++ transformer_lens/HookedTransformer.py | 8 +- transformer_lens/past_key_value_caching.py | 16 +++- 3 files changed, 113 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_kv_cache.py b/tests/unit/test_kv_cache.py index d33e20f9c..8992191b5 100644 --- a/tests/unit/test_kv_cache.py +++ b/tests/unit/test_kv_cache.py @@ -80,3 +80,98 @@ def test_multiple_new_tokens(): assert t.allclose( no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=1e-3 ) + + +def test_freeze_cache(): + past_left_attention_mask = utils.get_attention_mask( + model.tokenizer, + pre_prompt_tokens, + model.cfg.default_prepend_bos, + ) + + post_prompt_1 = " I'm headed to the church to play bingo." + new_tokens_1 = model.to_tokens(post_prompt_1, prepend_bos=False) + full_prompt_tokens_1 = t.cat([pre_prompt_tokens, new_tokens_1], dim=-1) + past_kv_cache_1 = HookedTransformerKeyValueCache.init_cache( + model.cfg, model.cfg.device, pre_prompt_tokens.shape[0] + ) + + post_prompt_2 = " shine your light on me, Miss Liberty" + new_tokens_2 = model.to_tokens(post_prompt_2, prepend_bos=False) + past_kv_cache_2 = HookedTransformerKeyValueCache.init_cache( + model.cfg, model.cfg.device, pre_prompt_tokens.shape[0] + ) + + model( + pre_prompt_tokens, + padding_side=padding_side, + past_kv_cache=past_kv_cache_1, + past_left_attention_mask=None, + ) + past_kv_cache_1.freeze() + with_cache_logits_1 = model( + new_tokens_1, + padding_side=padding_side, + past_kv_cache=past_kv_cache_1, + past_left_attention_mask=past_left_attention_mask, + ) + + model( + pre_prompt_tokens, + padding_side=padding_side, + past_kv_cache=past_kv_cache_2, + past_left_attention_mask=None, + ) + past_kv_cache_2.freeze() + model( + new_tokens_2, + padding_side=padding_side, + past_kv_cache=past_kv_cache_2, + past_left_attention_mask=past_left_attention_mask, + ) + + # Caches frozen at the same point should be identical + assert len(past_kv_cache_1.entries) == len(past_kv_cache_2.entries) + for entry_1, entry_2 in zip(past_kv_cache_1.entries, past_kv_cache_2.entries): + assert entry_1.past_keys.shape == entry_2.past_keys.shape + assert entry_1.past_values.shape == entry_2.past_values.shape + assert t.allclose(entry_1.past_keys, entry_2.past_keys, atol=1e-3) + assert t.allclose(entry_1.past_values, entry_2.past_values, atol=1e-3) + + # Rerunning the same prompt with a different cache that was frozen at the same + # point should give the same results + with_cache_2_logits_1 = model( + new_tokens_1, + padding_side=padding_side, + past_kv_cache=past_kv_cache_2, + past_left_attention_mask=past_left_attention_mask, + ) + assert t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3) + + # Test unfreeze + past_kv_cache_2.unfreeze() + with_cache_2_logits_1 = model( + new_tokens_1, + padding_side=padding_side, + past_kv_cache=past_kv_cache_2, + past_left_attention_mask=past_left_attention_mask, + ) + for entry_1, entry_2 in zip(past_kv_cache_1.entries, past_kv_cache_2.entries): + assert entry_1.past_keys.shape[1] < entry_2.past_keys.shape[1] + assert entry_1.past_values.shape[1] < entry_2.past_values.shape[1] + + # Rerunning the same prompt with a different cache should give different + # results + assert t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3) + past_left_attention_mask = utils.get_attention_mask( + model.tokenizer, + full_prompt_tokens_1, + model.cfg.default_prepend_bos, + ) + with_cache_2_logits_1 = model( + new_tokens_1, + padding_side=padding_side, + past_kv_cache=past_kv_cache_2, + past_left_attention_mask=past_left_attention_mask, + ) + assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=1e-3) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 6dfdc8380..65bb0439d 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -427,10 +427,10 @@ def forward( first transformer block, etc. Supports negative indexing. Useful for analysis of intermediate layers, eg finding neuron activations in layer 3 of a 24 layer model. Defaults to None (run the full model). past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values will be stored for every - attention head. If there are keys and values already in the cache, these will be prepended to the keys and values - for the new input, so that the new tokens can pay attention to previous tokens. This is useful for generating text, - because we don't need to repeat computation for tokens that have already been through the model. Defaults to None - (don't use caching). + attention head (unless the cache is frozen). If there are keys and values already in the cache, these will be + prepended to the keys and values for the new input, so that the new tokens can pay attention to previous tokens. + This is useful for generating text, because we don't need to repeat computation for tokens that have already been + through the model. Defaults to None (don't use caching). Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style language models - if you want a custom loss function, the recommended behaviour is returning the logits and then applying your diff --git a/transformer_lens/past_key_value_caching.py b/transformer_lens/past_key_value_caching.py index 27fbee287..f0241f0df 100644 --- a/transformer_lens/past_key_value_caching.py +++ b/transformer_lens/past_key_value_caching.py @@ -12,6 +12,7 @@ class HookedTransformerKeyValueCacheEntry: past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] + frozen: bool = False @classmethod def init_cache_entry( @@ -40,8 +41,9 @@ def append( updated_values: Float[ torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head" ] = torch.cat([self.past_values, new_values], dim=1) - self.past_keys = updated_keys - self.past_values = updated_values + if not self.frozen: + self.past_keys = updated_keys + self.past_values = updated_values return updated_keys, updated_values @@ -51,6 +53,8 @@ class HookedTransformerKeyValueCache: A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves! This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value. + + The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix. """ entries: List[HookedTransformerKeyValueCacheEntry] @@ -73,5 +77,13 @@ def init_cache( ] ) + def freeze(self): + for entry in self.entries: + entry.frozen = True + + def unfreeze(self): + for entry in self.entries: + entry.frozen = False + def __getitem__(self, idx): return self.entries[idx]