- Support passing in multiple tokens when using past_kv_cache.
- Add tests for past_kv_cache.
- Add documentation for past_kv_cache.
- Fix type hints for some components that assumed left_attention_mask has the same number of tokens as the input. This previously went unnoticed because there were no tests covering past_kv_cache.
UFO-101 committed Sep 18, 2023
1 parent 20a44fe commit 9b0e81d
Showing 5 changed files with 98 additions and 18 deletions.
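For context, a condensed sketch of the workflow the new tests exercise (see tests/unit/test_kv_cache.py below): fill the cache by running the pre-prompt once, then pass several new tokens in a single forward call instead of one token at a time. Model name and prompts here simply mirror the test file.

import torch as t
from transformer_lens import HookedTransformer, utils
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("solu-1l")
pre_tokens = model.to_tokens("I went to Staten Island,")
new_tokens = model.to_tokens(" to buy myself a mandolin", prepend_bos=False)

# Run the pre-prompt once to populate the cache, then feed all new tokens at once.
past_kv_cache = HookedTransformerKeyValueCache.init_cache(
    model.cfg, model.cfg.device, pre_tokens.shape[0]
)
model(pre_tokens, past_kv_cache=past_kv_cache)
past_left_attention_mask = utils.get_attention_mask(
    model.tokenizer, pre_tokens, model.cfg.default_prepend_bos
)
with_cache_logits = model(
    new_tokens,
    past_kv_cache=past_kv_cache,
    past_left_attention_mask=past_left_attention_mask,
)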
82 changes: 82 additions & 0 deletions tests/unit/test_kv_cache.py
@@ -0,0 +1,82 @@
# %%
import torch as t

from transformer_lens import HookedTransformer, utils
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

MODEL = "solu-1l"
model = HookedTransformer.from_pretrained(MODEL)

pre_prompt = "I went to Staten Island,"
padding_side = "left"
prepend_bos = True
pre_prompt_tokens = model.to_tokens(
    pre_prompt, prepend_bos=prepend_bos, padding_side=padding_side
)


def test_single_new_token():
    post_prompt = " Sharon"
    new_token = model.to_tokens(post_prompt, prepend_bos=False)
    full_prompt_tokens = t.cat([pre_prompt_tokens, new_token], dim=-1)
    assert full_prompt_tokens.shape[-1] == pre_prompt_tokens.shape[-1] + 1
    no_cache_logits = model(full_prompt_tokens, padding_side=padding_side)

    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
        model.cfg, model.cfg.device, pre_prompt_tokens.shape[0]
    )
    model(
        pre_prompt_tokens,
        padding_side=padding_side,
        past_kv_cache=past_kv_cache,
        past_left_attention_mask=None,
    )
    past_left_attention_mask = utils.get_attention_mask(
        model.tokenizer,
        pre_prompt_tokens,
        model.cfg.default_prepend_bos,
    )
    with_cache_logits = model(
        new_token,
        padding_side=padding_side,
        past_kv_cache=past_kv_cache,
        past_left_attention_mask=past_left_attention_mask,
    )
    print("no_cache_logits", no_cache_logits[:, -1])
    print("with_cache_logits", with_cache_logits[:, -1])
    assert t.allclose(no_cache_logits[:, -1], with_cache_logits[:, -1], atol=1e-3)
    assert t.allclose(no_cache_logits[:, -1:], with_cache_logits, atol=1e-3)


def test_multiple_new_tokens():
    post_prompt = " to buy myself a mandolin"
    new_tokens = model.to_tokens(post_prompt, prepend_bos=False)
    new_tokens_len = new_tokens.shape[-1]
    full_prompt_tokens = t.cat([pre_prompt_tokens, new_tokens], dim=-1)
    assert full_prompt_tokens.shape[-1] == pre_prompt_tokens.shape[-1] + new_tokens_len
    no_cache_logits = model(full_prompt_tokens, padding_side=padding_side)

    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
        model.cfg, model.cfg.device, pre_prompt_tokens.shape[0]
    )
    model(
        pre_prompt_tokens,
        padding_side=padding_side,
        past_kv_cache=past_kv_cache,
        past_left_attention_mask=None,
    )
    past_left_attention_mask = utils.get_attention_mask(
        model.tokenizer,
        pre_prompt_tokens,
        model.cfg.default_prepend_bos,
    )
    with_cache_logits = model(
        new_tokens,
        padding_side=padding_side,
        past_kv_cache=past_kv_cache,
        past_left_attention_mask=past_left_attention_mask,
    )
    assert t.allclose(no_cache_logits[:, -1], with_cache_logits[:, -1], atol=1e-3)
    assert t.allclose(
        no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=1e-3
    )
19 changes: 10 additions & 9 deletions transformer_lens/HookedTransformer.py
@@ -229,14 +229,11 @@ def input_to_embed(
assert (
past_kv_cache is not None
), "If past_left_attention_mask is not None, past_kv_cache must not be None"
assert (
tokens.shape[1] == 1
), "If past_left_attention_mask is not None, tokens must be a single token along the sequence dimension"
# past_kv_cache is not None, so we're doing caching.
# We need to extend the past_left_attention_mask.
# Append '1' to the right of the past_left_attention_mask to account for the new tokens
# Append '1's to the right of the past_left_attention_mask to account for each new token.
left_attention_mask = utils.extend_tensor_with_ones(
past_left_attention_mask
past_left_attention_mask, num_elements=tokens.shape[1]
)

else:
@@ -264,10 +261,6 @@ def input_to_embed(
assert cached_batch_size == batch_size
assert num_heads_in_cache == self.cfg.n_heads
assert d_head_in_cache == self.cfg.d_head
# If we want to generate from the empty string, we'd pass in an empty cache, so we need to handle that case
assert (
cache_ctx_length == 0 or ctx_length == 1
), "Pass in one token at a time after loading cache"
pos_offset = cache_ctx_length
if self.cfg.use_hook_tokens:
tokens = self.hook_tokens(tokens)
@@ -433,6 +426,11 @@ def forward(
stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 1 will run the embedding layer and the
first transformer block, etc. Supports negative indexing. Useful for analysis of intermediate layers, eg finding
neuron activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values will be stored for every
attention head. If there are keys and values already in the cache, these will be prepended to the keys and values
for the new input, so that the new tokens can pay attention to previous tokens. This is useful for generating text,
because we don't need to repeat computation for tokens that have already been through the model. Defaults to None
(don't use caching).
Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style language models -
if you want a custom loss function, the recommended behaviour is returning the logits and then applying your
@@ -499,6 +497,9 @@ def forward(
if return_type == "logits":
return logits
else:
assert (
tokens is not None
), "tokens must be passed in if return_type is 'loss' or 'both'"
loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
if return_type == "loss":
return loss
8 changes: 4 additions & 4 deletions transformer_lens/components.py
@@ -82,7 +82,7 @@ def forward(
self,
tokens: Int[torch.Tensor, "batch pos"],
past_kv_pos_offset: int = 0,
left_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
"""
Forward pass for positional embeddings.
@@ -532,7 +532,7 @@ def forward(
],
past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
"""
shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
@@ -661,7 +661,7 @@ def apply_causal_mask(
torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"
],
past_kv_pos_offset: int = 0,
left_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
):
# The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it's just a single token.
query_ctx_length = attn_scores.size(-2)
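For illustration, a minimal generic sketch (not the library's apply_causal_mask, just the same shape conventions) of how a causal mask can be built when the queries start past_kv_pos_offset positions into the sequence, so every query can still see all cached key positions while left-padding stays masked out:

import torch

def causal_mask_with_offset(attn_scores, past_kv_pos_offset=0, left_attention_mask=None):
    # attn_scores: [batch, head, query_pos, key_pos], key_pos = past_kv_pos_offset + query_pos
    q_len = attn_scores.size(-2)
    k_len = attn_scores.size(-1)
    # Query i sits at absolute position past_kv_pos_offset + i, so it may attend to
    # every key position j with j <= past_kv_pos_offset + i.
    causal = torch.tril(
        torch.ones(q_len, k_len, dtype=torch.bool, device=attn_scores.device),
        diagonal=past_kv_pos_offset,
    )
    mask = causal[None, None, :, :]
    if left_attention_mask is not None:
        # left_attention_mask: [batch, key_pos]; 0 marks left-padding no query should attend to.
        mask = mask & left_attention_mask[:, None, None, :].bool()
    return attn_scores.masked_fill(~mask, torch.finfo(attn_scores.dtype).min)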
@@ -1001,7 +1001,7 @@ def forward(
Float[torch.Tensor, "batch pos d_model"]
] = None,
past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
left_attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
"""A single Transformer block.
3 changes: 0 additions & 3 deletions transformer_lens/past_key_value_caching.py
@@ -51,9 +51,6 @@ class HookedTransformerKeyValueCache:
A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!
This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.
Generation is assumed to be done by initializing with some prompt and then continuing iteratively one token at a time. So append only works for adding a single token's worth of keys and values, but the cache can be initialized with many.
"""

entries: List[HookedTransformerKeyValueCacheEntry]
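As a rough illustration of the shape bookkeeping described in the docstring (a simplified stand-in, not the actual HookedTransformerKeyValueCacheEntry), an append that concatenates along the position dimension naturally handles more than one new token at a time:

from dataclasses import dataclass
import torch

@dataclass
class SimpleKeyValueCacheEntry:
    # past_keys / past_values: [batch, pos_so_far, n_heads, d_head]
    past_keys: torch.Tensor
    past_values: torch.Tensor

    def append(self, new_keys: torch.Tensor, new_values: torch.Tensor):
        # new_keys / new_values: [batch, new_pos, n_heads, d_head]; after this commit,
        # new_pos may be greater than one.
        self.past_keys = torch.cat([self.past_keys, new_keys], dim=1)
        self.past_values = torch.cat([self.past_values, new_values], dim=1)
        return self.past_keys, self.past_values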
4 changes: 2 additions & 2 deletions transformer_lens/utils.py
@@ -1037,8 +1037,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
set_nested_attr(self, default_location, default_value)


def extend_tensor_with_ones(tensor, dim=1):
def extend_tensor_with_ones(tensor, dim=1, num_elements=1):
new_elements = torch.ones(
(tensor.shape[0], 1), dtype=tensor.dtype, device=tensor.device
(tensor.shape[0], num_elements), dtype=tensor.dtype, device=tensor.device
)
return torch.cat([tensor, new_elements], dim=dim)
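With the new num_elements argument, the helper can grow a left-padded attention mask by however many tokens the next forward pass contains, for example:

import torch
from transformer_lens.utils import extend_tensor_with_ones

mask = torch.tensor([[0, 1, 1, 1]])                 # left-padded mask for one sequence
extended = extend_tensor_with_ones(mask, num_elements=3)
# tensor([[0, 1, 1, 1, 1, 1, 1]]) -- one new '1' per token in the new forward pass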
