Add support for left padding #344

Merged
merged 25 commits into from
Sep 1, 2023
Changes from 20 commits
25 commits
dd2db4a  feat(*): add support for left padding (soheeyang, Jul 11, 2023)
4ce1cde  feat(*): add tests (soheeyang, Jul 15, 2023)
3eb577c  chore(*): reformat the code (soheeyang, Jul 15, 2023)
0c62d0a  Merge branch 'main' of https://github.com/neelnanda-io/TransformerLen… (soheeyang, Jul 15, 2023)
5e041c0  fix(*): fix minor runtime errors and reformat the code (soheeyang, Jul 15, 2023)
63e70e9  feat(demo): add info of left padding in exploratory demo (soheeyang, Jul 15, 2023)
c17d860  feat(demo): add missing commit (soheeyang, Jul 16, 2023)
b76c054  fix(PosEmbed): remove deprecated code (soheeyang, Jul 16, 2023)
bf5686f  fix(test_left_padding): add atol to allclose (soheeyang, Jul 16, 2023)
1bba99a  feat(*): enable overriding padding_side (soheeyang, Jul 24, 2023)
1d2d2d5  refactor(*): change method and attribute names (soheeyang, Jul 26, 2023)
c3c6f0f  refactor(*): change None to USE_DEFAULT_VALUE (soheeyang, Jul 26, 2023)
fd3ba96  Merge branch 'main' of https://github.com/neelnanda-io/TransformerLen… (soheeyang, Jul 26, 2023)
9283bfc  refactor(HookedTransformer): remove unnecessary function calls (soheeyang, Jul 26, 2023)
1591711  Merge branch 'main' of https://github.com/soheeyang/TransformerLens i… (soheeyang, Jul 30, 2023)
0591924  fix(get_attention_mask): fix and add tests for sep_token == pad_token… (soheeyang, Aug 13, 2023)
c8cb751  feat(*): change decorator to LocallyOverridenDefaults context manager (soheeyang, Aug 13, 2023)
6899293  chore(*): change typing and reformat the code (soheeyang, Aug 13, 2023)
ab5392a  Merge branch 'main' of https://github.com/neelnanda-io/TransformerLen… (soheeyang, Aug 14, 2023)
6c5ce31  fix(HookedTransformer): fix type check error (soheeyang, Aug 14, 2023)
01e91ac  chore(get_attention_mask): add more comments (soheeyang, Aug 14, 2023)
d43b18f  Merge branch 'main' of https://github.com/neelnanda-io/TransformerLen… (soheeyang, Aug 23, 2023)
d07856c  fix(apply_causal_mask): fix error on multi gpu setting (soheeyang, Aug 23, 2023)
b47a669  chore(*): reformat the code (soheeyang, Aug 28, 2023)
f4bc471  fix(to_tokens): add missing commit during merge (soheeyang, Aug 28, 2023)
8 changes: 6 additions & 2 deletions demos/Exploratory_Analysis_Demo.ipynb
@@ -497,7 +497,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start). There's a range of other ways of solving this, eg you can also index more intelligently to get the final logit"
"**Gotcha**: It's important that all of your prompts have the same number of tokens. If they're different lengths, then the position of the \"final\" logit where you can check logit difference will differ between prompts, and this will break the below code. The easiest solution is just to choose your prompts carefully to have the same number of tokens (you can eg add filler words like The, or newlines to start).\n",
"\n",
"There's a range of other ways of solving this, eg you can index more intelligently to get the final logit. A better way is to just use left padding by setting `model.tokenizer.padding_side = 'left'` before tokenizing the inputs and running the model; this way, you can use something like `logits[:, -1, :]` to easily access the final token outputs without complicated indexing. TransformerLens checks the value of `padding_side` of the tokenizer internally, and if the flag is set to be `'left'`, it adjusts the calculation of absolute position embedding and causal masking accordingly.\n",
"\n",
"In this demo, though, we stick to using the prompts of the same number of tokens because we want to show some visualisations aggregated along the batch dimension later in the demo."
]
},
{
@@ -2921,7 +2925,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.14"
"version": "3.9.17"
},
"vscode": {
"interpreter": {
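The markdown cell added above describes the left-padding workflow only in prose; the sketch below shows that usage in code. It is not part of the PR's diff, and the model choice, prompts, and final decode step are illustrative assumptions.

# Minimal usage sketch (illustrative, not from the PR): with left padding,
# every prompt's final real token sits at position -1, so the last-token
# logits can be read with a single slice regardless of prompt length.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")  # assumed model choice
model.tokenizer.padding_side = "left"  # pad shorter prompts on the left

prompts = [
    "The Eiffel Tower is in the city of",
    "I went to the shop to buy a",
]

logits = model(prompts)            # [batch, pos, d_vocab]
final_logits = logits[:, -1, :]    # last real token of every prompt, no per-prompt indexing
next_tokens = final_logits.argmax(dim=-1)
print([model.tokenizer.decode(t) for t in next_tokens.tolist()])

With right padding, the same logits[:, -1, :] slice would land on pad positions for the shorter prompts, which is exactly the situation the padding_side handling in this PR avoids.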
170 changes: 170 additions & 0 deletions tests/unit/test_left_padding.py
@@ -0,0 +1,170 @@
import pytest
import torch

from transformer_lens import HookedTransformer, utils


class TestLeftPadding:
prompts = [
"Hello world!",
"How are you today?",
"I'm fine, thank you.",
"I am happy.",
]

# helpers
def check_outputs_identity(
self,
i,
single_outputs,
left_outputs,
right_outputs,
left_token_start,
left_token_end,
right_token_start,
right_token_end,
):
atol = 1e-4

assert torch.allclose(
left_outputs[i, left_token_start:left_token_end, :],
right_outputs[i, right_token_start:right_token_end, :],
atol=atol,
)

assert torch.allclose(
left_outputs[i, left_token_start:left_token_end, :],
single_outputs[0],
atol=atol,
)

assert torch.allclose(
right_outputs[i, right_token_start:right_token_end, :],
single_outputs[0],
atol=atol,
)

# fixtures
@pytest.fixture(scope="class", params=["gpt2-small", "facebook/opt-125m"])
def model_name(self, request):
return request.param

@pytest.fixture(scope="class")
def model(self, model_name):
model = HookedTransformer.from_pretrained(model_name)
return model

# tests
@pytest.mark.parametrize("padding_side", ["left", "right"])
@pytest.mark.parametrize("prepend_bos", [True, False])
def test_pos_embed(self, model, padding_side, prepend_bos):
# setup
model.tokenizer.padding_side = padding_side

prompts = self.prompts
tokens = model.to_tokens(prompts, prepend_bos=prepend_bos)
str_tokens = model.to_str_tokens(prompts, prepend_bos=prepend_bos)

left_attention_mask = utils.get_attention_mask(
model.tokenizer, tokens, prepend_bos
) # [batch pos]

output_pos_embed = model.pos_embed(
tokens, 0, left_attention_mask=left_attention_mask
) # [batch pos d_model]

# check if the output pos_embeds have the correct shape
assert output_pos_embed.shape == (
tokens.shape[0],
tokens.shape[1],
model.pos_embed.W_pos.shape[1],
)

# check if the target pos_embeds are the same as the output pos_embeds
target_position_ids = torch.tensor(
sum([list(range(len(t))) for t in str_tokens], []), device=tokens.device
)
target_output_pos_embed = model.pos_embed.W_pos[target_position_ids, :]

attended_output_pos_embed = output_pos_embed[left_attention_mask.bool()]

assert torch.allclose(
attended_output_pos_embed, target_output_pos_embed, atol=1e-4
)

# padded positions should have zero pos_embed
assert output_pos_embed[~left_attention_mask.bool()].sum() == 0

def test_left_padding_by_comparing_outputs(self, model):
prompts = self.prompts

num_str_tokens_list = [len(t) for t in model.to_str_tokens(prompts)]

# left padding output
model.tokenizer.padding_side = "left"
left_logits, left_cache = model.run_with_cache(prompts)
left_last_logits = left_logits[:, -1, :]
left_first_token_positions = left_logits.shape[1] - torch.tensor(
num_str_tokens_list, device=left_logits.device
)
left_first_logits = left_logits[
torch.arange(len(prompts)), left_first_token_positions, :
].squeeze(1)

# right padding output
model.tokenizer.padding_side = "right"
right_logits, right_cache = model.run_with_cache(prompts)
right_last_token_positions = (
torch.tensor(num_str_tokens_list, device=right_logits.device) - 1
)
right_last_logits = right_logits[
torch.arange(len(prompts)), right_last_token_positions, :
].squeeze(1)
right_first_logits = right_logits[:, 0, :]

# check if the left and right padding outputs are the same for the first and last tokens
assert torch.allclose(left_last_logits, right_last_logits, atol=1e-4)
assert torch.allclose(left_first_logits, right_first_logits, atol=1e-4)

# check if the left and right padding outputs are the same for all tokens
# and if the batched padded outputs are the same as the single prompt outputs
right_token_start = 0
left_token_end = left_logits.shape[1]
for i, (prompt, left_token_start, right_token_end) in enumerate(
zip(
prompts,
left_first_token_positions.tolist(),
(right_last_token_positions + 1).tolist(),
)
):
single_logits, single_cache = model.run_with_cache(prompt)

assert (
right_token_end - right_token_start
== left_token_end - left_token_start
== single_logits.shape[1]
)

self.check_outputs_identity(
i,
single_logits,
left_logits,
right_logits,
left_token_start,
left_token_end,
right_token_start,
right_token_end,
)

# check cache
for name in ["k6a", "pre2", "embed", "k6", "scale4ln1", "pre5"]:
self.check_outputs_identity(
i,
single_cache[name],
left_cache[name],
right_cache[name],
left_token_start,
left_token_end,
right_token_start,
right_token_end,
)
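The test_pos_embed test above asserts two properties: attended tokens get position indices counted from 0 starting at each row's first real token, and padded positions contribute a zero positional embedding. The sketch below shows one common way to derive such position indices from an attention mask; it illustrates the property being tested and is not necessarily the PR's own implementation.

# Illustrative only: left-padding position ids from an attention mask via a
# cumulative sum; padded positions are pinned to 0 and their embedding zeroed out.
import torch

def left_pad_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [batch, pos], 1 for real tokens, 0 for (left) padding
    position_ids = attention_mask.cumsum(dim=-1) - 1  # -1 before the first real token
    return position_ids.clamp(min=0)

mask = torch.tensor([[0, 0, 1, 1, 1],
                     [1, 1, 1, 1, 1]])
print(left_pad_position_ids(mask))
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# The positional embedding would then be W_pos[position_ids] multiplied by
# attention_mask[..., None], so padded positions end up with an all-zero embedding.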
123 changes: 115 additions & 8 deletions tests/unit/test_utils.py
@@ -209,21 +209,128 @@ def test_fail(self, x: torch.Tensor):

def test_override_or_use_default_flag():
# Case when override is not None
    assert utils.override_or_use_default_value(default_flag=True, override=True) == True
    assert utils.override_or_use_default_value(default_flag=True, override=False) == False
    assert utils.override_or_use_default_value(default_flag=False, override=True) == True
    assert utils.override_or_use_default_value(default_flag=False, override=False) == False

    # Case when override is None
    assert utils.override_or_use_default_value(default_flag=True, override=None) == True
    assert utils.override_or_use_default_value(default_flag=False, override=None) == False

    # Case when override is not passed
    assert utils.override_or_use_default_value(default_flag=True) == True
    assert utils.override_or_use_default_value(default_flag=False) == False


class TestAttentionMask:
prompts = [
"Hello world!",
"How are you today?",
"I'm fine, thank you.",
"I am happy.",
]

prompts_with_sep = [
"I like cats<|endoftext|>Cats are so cute",
"Hello world!",
"How are you<|endoftext|>I am fine, thanks",
]

# fixtures
@pytest.fixture(scope="class", params=["gpt2-small", "facebook/opt-125m"])
def model_name(self, request):
return request.param

@pytest.fixture(scope="class")
def model(self, model_name):
return HookedTransformer.from_pretrained(model_name)

# tests
@pytest.mark.parametrize("padding_side", ["left", "right"])
@pytest.mark.parametrize("prepend_bos", [True, False])
@pytest.mark.parametrize("prompts_with_sep", [True, False])
def test_get_attention_mask(
self, model, padding_side, prepend_bos, prompts_with_sep
):
# setup
model.tokenizer.padding_side = padding_side
model.tokenizer.sep_token_id = model.tokenizer.pad_token_id
prepend_bos = prepend_bos

prompts = self.prompts_with_sep if prompts_with_sep else self.prompts
tokens = model.to_tokens(prompts, prepend_bos=prepend_bos)

attention_mask = utils.get_attention_mask(
model.tokenizer, tokens, prepend_bos=prepend_bos
) # [batch pos]

# dimension should be the same
assert attention_mask.shape == tokens.shape

# number of attended tokens for each sequence
# should be the same as the number of 1s in the attention mask for that sequence
str_tokens = model.to_str_tokens(prompts, prepend_bos=prepend_bos)
intended_num_attended_tokens = torch.tensor(
[len(t) for t in str_tokens], device=attention_mask.device
)
assert (intended_num_attended_tokens == attention_mask.sum(dim=1)).all()

# all the masked tokens should be the padding token
assert (tokens[attention_mask == 0] == model.tokenizer.pad_token_id).all()

if padding_side == "right":
# the first token is always attended
assert (attention_mask[:, 0] == 1).all()

# attended tokens are at the beginning of the sequence
for i, num in enumerate(intended_num_attended_tokens.tolist()):
assert (attention_mask[i, 0:num] == 1).all()

else: # left padding case
# the last token is always attended
assert (attention_mask[:, -1] == 1).all()

# attended tokens are at the end of the sequence
for i, num in enumerate(intended_num_attended_tokens.tolist()):
assert (attention_mask[i, -num:] == 1).all()

# the following tests make sense only when the prompts do not contain the separator token
if not prompts_with_sep:
non_pad_token_mask = (tokens != model.tokenizer.pad_token_id).int()
attended_but_non_pad_mask = attention_mask != non_pad_token_mask
if model.tokenizer.bos_token == model.tokenizer.pad_token and prepend_bos:
# if bos_token is the same as pad_token and prepend_bos is True,
# then there is one attended but non-pad token (bos token) in each sequence
assert attended_but_non_pad_mask.sum() == tokens.shape[0]
else:
# otherwise, there should be no attended but non-pad token
assert attended_but_non_pad_mask.sum() == 0

@pytest.mark.parametrize("prepend_bos", [True, False])
def test_get_causal_mask_for_left_padding(self, model, prepend_bos):
model.tokenizer.padding_side = "left"

prompts = self.prompts
tokens = model.to_tokens(prompts, prepend_bos=prepend_bos)

left_attention_mask = utils.get_attention_mask(
model.tokenizer, tokens, prepend_bos=prepend_bos
) # [batch pos]

final_mask = utils.get_causal_mask_for_left_padding(left_attention_mask)

pad_token_mask = ~left_attention_mask.bool()
assert final_mask[pad_token_mask].sum() == 0

attn = model.blocks[0].attn
causal_pad_mask = ~attn.mask[: tokens.shape[1], : tokens.shape[1]]
assert final_mask[:, causal_pad_mask].sum() == 0
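The last test checks two invariants of the combined mask: every padded query position is fully masked out, and everything above the causal diagonal remains masked. The sketch below reproduces those invariants with an assumed construction; it is not necessarily what utils.get_causal_mask_for_left_padding does internally.

# Illustrative only: combine an ordinary causal mask with a left-padding
# attention mask so that pad queries attend to nothing and real queries
# attend only to earlier real tokens.
import torch

def causal_mask_with_left_padding(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [batch, pos], 1 for real tokens, 0 for left padding
    pos = attention_mask.size(1)
    causal = torch.tril(torch.ones(pos, pos, dtype=torch.bool))  # [pos, pos]
    real = attention_mask.bool()
    pairwise = real[:, :, None] & real[:, None, :]  # query and key both real
    return causal[None] & pairwise                  # [batch, query_pos, key_pos]

mask = torch.tensor([[0, 1, 1], [1, 1, 1]])
final = causal_mask_with_left_padding(mask)
assert final[~mask.bool()].sum() == 0  # padded query rows see nothing
assert final[:, ~torch.tril(torch.ones(3, 3, dtype=torch.bool))].sum() == 0  # still causal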