diff --git a/tests/integration/test_kv_cache.py b/tests/integration/test_kv_cache.py
index a98ba7de6..baab6696a 100644
--- a/tests/integration/test_kv_cache.py
+++ b/tests/integration/test_kv_cache.py
@@ -213,6 +213,28 @@ def test_freeze_cache(pretrained):
     assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol)
 
 
+def test_kv_cache_with_custom_attention_mask(pretrained):
+    model, atol = pretrained
+    prompt_pre = "An apple"
+    prompt_post = " a day keeps junk the"
+    prompt_whole = "An apple a day keeps the"
+    tokens_pre = model.to_tokens(prompt_pre)
+    tokens_post = model.to_tokens(prompt_post, prepend_bos=False)
+    tokens_whole = model.to_tokens(prompt_whole)
+    correct_logits = model(tokens_whole)
+
+    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
+        model.cfg, model.cfg.device, tokens_pre.shape[0]
+    )
+    model(tokens_pre, past_kv_cache=past_kv_cache)
+    exp_logits = model(
+        tokens_post,
+        attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
+        past_kv_cache=past_kv_cache,
+    )
+    assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol)
+
+
 def test_kv_cache_and_start_at_layer(pretrained):
     model, atol = pretrained
     pre_prompt = "I went to Staten Island,"
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 56096484c..c084c02b1 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -8,6 +8,7 @@
 alteration of activations in individual components like attention heads and MLP
 layers, facilitating a deeper understanding of the internal workings of transformers like GPT-2.
 """
+
 import logging
 import os
 from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
         if tokens.device.type != self.cfg.device:
             tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
 
-        if attention_mask is not None:
+        if (
+            (self.tokenizer and self.tokenizer.padding_side == "left")
+            or attention_mask is not None
+            or past_kv_cache is not None
+        ):
+            # This means we need to have an explicit attention mask.
+            if attention_mask is None:
+                # If the padding side is left or we are using caching, we need to compute the attention
+                # mask for the adjustment of absolute positional embeddings and attention masking so
+                # that pad tokens are not attended.
+                if prepend_bos is USE_DEFAULT_VALUE:
+                    prepend_bos = self.cfg.default_prepend_bos
+                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
+
             assert attention_mask.shape == tokens.shape, (
                 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
                 f"{tokens.shape}"
             )
             attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
-        elif (
-            self.tokenizer and self.tokenizer.padding_side == "left"
-        ) or past_kv_cache is not None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
-
         if past_kv_cache is not None:
             # past_kv_cache is not None, so we're doing caching.
             # We need to extend the previous attention_mask.