fix the bug that attention_mask and past_kv_cache cannot work together #772

Open · wants to merge 3 commits into base: dev

Changes from all commits
22 changes: 22 additions & 0 deletions tests/integration/test_kv_cache.py
@@ -213,6 +213,28 @@ def test_freeze_cache(pretrained):
    assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol)


def test_kv_cache_with_custom_attention_mask(pretrained):
    model, atol = pretrained
    prompt_pre = "An apple"
    prompt_post = " a day keeps junk the"
    prompt_whole = "An apple a day keeps the"
    tokens_pre = model.to_tokens(prompt_pre)
    tokens_post = model.to_tokens(prompt_post, prepend_bos=False)
    tokens_whole = model.to_tokens(prompt_whole)
    correct_logits = model(tokens_whole)

    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
        model.cfg, model.cfg.device, tokens_pre.shape[0]
    )
    model(tokens_pre, past_kv_cache=past_kv_cache)
    exp_logits = model(
        tokens_post,
        attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
        past_kv_cache=past_kv_cache,
    )
    assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol)


def test_kv_cache_and_start_at_layer(pretrained):
    model, atol = pretrained
    pre_prompt = "I went to Staten Island,"
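For readers skimming the diff, the new test corresponds to a usage pattern like the sketch below. It is only an illustration, not part of the PR: the "gpt2" model name is an arbitrary choice, and the prompts and mask values simply mirror the test above (the mask zeroes out the " junk" token so that the cached prefix plus the masked continuation should reproduce a single pass over "An apple a day keeps the").

import torch as t

from transformer_lens import HookedTransformer
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")

# Prefill the KV cache with the first chunk of the prompt.
past_kv_cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, 1)
model(model.to_tokens("An apple"), past_kv_cache=past_kv_cache)

# Continue from the cache, passing a custom attention mask that hides " junk".
tokens_post = model.to_tokens(" a day keeps junk the", prepend_bos=False)
logits = model(
    tokens_post,
    attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
    past_kv_cache=past_kv_cache,
)

# With the fix, logits[:, -1] should match (up to tolerance) a single forward pass
# over "An apple a day keeps the".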
27 changes: 15 additions & 12 deletions transformer_lens/HookedTransformer.py
@@ -8,6 +8,7 @@
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""

import logging
import os
from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))

-        if attention_mask is not None:
+        if (
+            (self.tokenizer and self.tokenizer.padding_side == "left")
+            or attention_mask is not None
+            or past_kv_cache is not None
+        ):
+            # This means we need to have an explicit attention mask.
+            if attention_mask is None:
+                # If the padding side is left or we are using caching, we need to compute the attention
+                # mask for the adjustment of absolute positional embeddings and attention masking so
+                # that pad tokens are not attended.
+                if prepend_bos is USE_DEFAULT_VALUE:
+                    prepend_bos = self.cfg.default_prepend_bos
+                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
+
            assert attention_mask.shape == tokens.shape, (
                f"Attention mask shape {attention_mask.shape} does not match tokens shape "
                f"{tokens.shape}"
            )
            attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
-        elif (
-            self.tokenizer and self.tokenizer.padding_side == "left"
-        ) or past_kv_cache is not None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

            if past_kv_cache is not None:
                # past_kv_cache is not None, so we're doing caching.
                # We need to extend the previous attention_mask.
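The comment above refers to the bookkeeping that makes caching and a user-supplied mask compose: the mask over tokens already in the cache has to be concatenated with the mask over the new tokens, so that every key position (cached and new) is either attended or ignored. The helper below, extend_attention_mask, is a hypothetical stand-in written only to illustrate the shapes involved; the actual extension happens in the cache code that is collapsed out of this diff.

import torch as t

def extend_attention_mask(past_mask: t.Tensor, new_mask: t.Tensor) -> t.Tensor:
    # past_mask: [batch, n_cached] -- mask over tokens already held in the KV cache
    # new_mask:  [batch, n_new]    -- mask over tokens of the current forward pass
    # Attention needs a single mask covering all n_cached + n_new key positions.
    return t.cat([past_mask, new_mask], dim=-1)

past_mask = t.tensor([[1, 1, 1]])       # "<bos>An apple", all three cached tokens attended
new_mask = t.tensor([[1, 1, 1, 0, 1]])  # " a day keeps junk the", with " junk" masked out
print(extend_attention_mask(past_mask, new_mask))
# tensor([[1, 1, 1, 1, 1, 1, 0, 1]])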