
Add support for left padding #344

Merged: 25 commits merged into TransformerLensOrg:main on Sep 1, 2023

Conversation

@soheeyang (Contributor) commented Jul 16, 2023

Description

This PR introduces support for left padding for decoder-only models by properly adjusting the absolute position embeddings and the causal/attention masks. This change makes it easier to retrieve the final outputs, e.g., logits[:, -1, :], removing the need for complicated indexing even when a batch of sequences of varying lengths is processed.

  • Introduction of util.get_attention_mask() Method: Introduced a new method that is called internally by the model in forward() to generate attention masks based on the input tokens (it works for both left-padded and right-padded tokens) and the prepend_bos flag used to create the tokens. This mask ensures that the model does not attend to padding tokens during processing; while it is not necessary when right padding is used, it becomes necessary with left padding in order to adjust the absolute position embeddings and attention. The method also handles the special case where the BOS token is identical to the pad token and is prepended to the input sequences, so that such a token is still properly attended to. (A sketch of this masking and position-shifting logic is given after this list.)

    • Design consideration: One thing to note is that the prepend_bos passed to the function should be the same as the one used to create tokens. At first, I considered making the function infer the value of prepend_bos by inspecting tokens, without the need to pass the value as a parameter, but that would require several assumptions about how tokens was processed, which might not always hold. Therefore, I decided to go with the current design.
  • Adjustment of Positional Embeddings: Made changes to the positional embeddings such that the first real token in each sequence receives the 0th positional embedding. This is critical to the proper functioning of the model when left padding is used. The positional embeddings of the pad tokens are set to zero vectors.

    • Design consideration: I separated the calculation of the positional embeddings for the left padding from that of the right padding for computational efficiency.
  • Adjustment of Causal Masking: Made changes to the causal mask so that it also prevents the model from attending to the pad tokens that appear in front of the real tokens.

    • Design consideration: Again, I separated the calculation for the left padding from that of the right padding for computational efficiency.
  • Test Cases: Added comprehensive test cases to ensure the correctness of the left-padding feature, attention mask generation, the adjusted positional embeddings, the adjusted causal mask, and identical output behaviour across right-padded tokens, left-padded tokens, and single (unbatched) sequences.

  • Explanation: Added an explanation of left-padding support to the exploratory analysis demo.
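
To make the masking and position-shifting logic above concrete, here is a minimal sketch (this is not the PR's actual implementation; the helper names and pad_token_id are illustrative, and the real get_attention_mask() additionally handles the case where a prepended BOS token shares its id with the pad token):

import torch

def sketch_attention_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # 1 for real tokens, 0 for pads
    return (tokens != pad_token_id).int()

def sketch_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # A cumulative count of real tokens gives the first real token position 0;
    # pad positions are clamped to 0 (their positional embeddings are zeroed out
    # separately in the PR).
    return (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)

tokens = torch.tensor([
    [0, 0, 11, 12, 13],    # two left pads, assuming pad_token_id = 0
    [21, 22, 23, 24, 25],  # no padding
])
mask = sketch_attention_mask(tokens, pad_token_id=0)
print(sketch_position_ids(mask))
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])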

Fixes #240

Type of change

  • New feature (non-breaking change which adds functionality)

Example code

before

# Requires tedious indexing.
# If you want the last-token outputs not only for the logits
# but also for the cached activations, etc., having to do this kind of indexing every time becomes very annoying.

model.tokenizer.padding_side = "right"

num_str_tokens_list = [len(t) for t in model.to_str_tokens(prompts)]
logits, cache = model.run_with_cache(prompts)
last_token_positions = torch.tensor(num_str_tokens_list, device=logits.device) - 1
last_logits = logits[torch.arange(len(prompts)), last_token_positions, :].squeeze(1)

after

# life made easier

model.tokenizer.padding_side = "left"

logits, cache = model.run_with_cache(prompts)
last_logits = logits[:, -1, :]

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@jbloomAus (Collaborator):

Thanks! I'm on holiday but @JayBaileyCS will take a look and I'll also look once he's done that since we've discussed this feature before :)

@soheeyang (Contributor, Author) commented Jul 16, 2023

Hi Joseph,

Thank you for the comment! I've gone through the discussions you had with @JayBaileyCS on this feature. One thing I want to discuss is what you wrote in #291 (comment).

I think there are three main options for how to enable left padding:

(BTW, I think padding_side is better than left_pad because it matches what transformers tokenizers actually use, so I will use padding_side in the example code below.)

  1. Add a padding_side="right" argument to all the HookedTransformer methods that preprocess tokens (these are the methods that take prepend_bos as a parameter, so they are easy to locate), so that left padding is used when padding_side="left" is passed as the argument.
  2. Add a set_default_padding_side() method to HookedTransformer that the user can call to change the padding side. The default value of the attribute is determined by default_padding_side = "right" in HookedTransformerConfig.
    # in HookedTransformer class
    
    @property
    def padding_side(self):
        return self.tokenizer.padding_side
    
    def set_default_padding_side(self, padding_side: str):
        assert padding_side in ['left', 'right'], f"padding_side must be either 'left' or 'right', but {padding_side} is given"
        self.tokenizer.padding_side = padding_side
    • This is exactly what I added for prepend_bos (and got merged to the main branch by @neelnanda-io yesterday) in Introduce Global prepend_bos Attribute to HookedTransformer #343, so that options 1 (the original code) and 2 (the newly added code) are used together to solve the issue (users can pass padding_side="left" as an argument to the string-preprocessing methods to override the default model.padding_side).
  3. Add default_padding_side: str = "right" to HookedTransformerConfig and let HookedTransformer instance access this value to determine the default padding side.
    # in HookedTransformerConfig class
    
    default_padding_side: str = "right"
    
    def __post_init__(self):
        assert self.default_padding_side in ['left', 'right'], f"default_padding_side must be either 'left' or 'right', but {self.default_padding_side} is given"
    • This might be better than using a setter (option 2) to set the default value (I think we should choose between options 1+2 and 1+3) if we want to nudge the user to set the default padding side only once, when the HookedTransformer instance is created. Why would we want this? To prevent users from shooting themselves in the foot by unknowingly using inconsistent padding sides, e.g., by accidentally not calling set_default_padding_side() in time. (I also made this suggestion in the PR for prepend_bos: Introduce Global prepend_bos Attribute to HookedTransformer #343 (comment).) A usage sketch of options 1+3 is shown below.
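
To make options 1+3 concrete, here is a usage sketch based on the interface proposed in this thread (the default_padding_side keyword to from_pretrained and the per-call padding_side argument are part of this proposal, not an existing API):

from transformer_lens import HookedTransformer

# Set the default padding side once, at model-creation time (option 3)
model = HookedTransformer.from_pretrained("gpt2", default_padding_side="left")

prompts = ["The cat sat on the mat", "Hi"]

# Uses the default (left) padding
logits, cache = model.run_with_cache(prompts)
last_logits = logits[:, -1, :]

# Locally override the default for a single call (option 1)
str_tokens = model.to_str_tokens(prompts, padding_side="right")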

@soheeyang (Contributor, Author) commented Jul 16, 2023

I initially wrote in the comment above that I'd prefer option 1+2, but I've changed my mind to option 1+3: model.cfg.default_padding_side is used as the default, which is "right" by default but can be set differently at model-creation time, as in model = HookedTransformer.from_pretrained("gpt2", default_padding_side="left"), and users can still locally override it when making a method call, as in model.run_with_cache(prompt, padding_side="right"). The reason is that, as I described in the last part of option 3, it is not a good idea to make the user call a setter for a default value separately after the instance is initialised. For example, users can make this kind of mistake if we go with option 2 instead of option 3 and they are not careful enough:

# (Let's assume that I am using Jupyter notebook.)
[1] model = HookedTransformer.from_pretrained("gpt2")

# (These logits and cache are calculated with right padding.)
[2] movie_review_logits, movie_review_cache = model.run_with_cache(movie_reviews)

# ... (I do a lot of things and forget the fact that movie_review_logits and movie_review_cache used right padding)

# Oh! I just realised that I can use left padding as default.
# It would be much more convenient to use left padding as default.
[203] model.set_default_padding_side("left")

# (These logits and cache are calculated with left padding.)
[204] book_review_logits, book_review_cache = model.run_with_cache(book_reviews)

...

# Okay, now let's investigate the last token activations of the model...
[246] movie_review_last_pre8 = movie_review_cache['pre', 8][:, -1, :]  # This is an unintentional bug!!!!
      book_review_last_pre8 = book_review_cache['pre', 8][:, -1, :]
      ...

In addition, it might be confusing for readers of the code when there are model.cfg.default_padding_side, model.padding_side, and model.to_str(padding_side=...); e.g., it is not explicit how model.cfg.default_padding_side and model.padding_side differ without reading the code thoroughly. With option 1+3, we just have model.cfg.default_padding_side and model.to_str(padding_side=...), which is much clearer.

@jbloomAus @JayBaileyCS What do you think about going with 1+3? (Whether we go with option 1+2 or 1+3, it would be good to use the same interface as the one for prepend_bos for API consistency. The current interface for prepend_bos is option 1+2, but I'd like to change it to option 1+3 as well.)

@neelnanda-io For exactly the same reasons, I think it would be better to change prepend_bos to also use option 3 rather than option 2 (so that it becomes option 1+3). What do you think about this? If you agree with this suggestion, I will make a quick separate PR for prepend_bos as well.

@neelnanda-io (Collaborator):

Oh lol I didn't realise you'd done 2 rather than 3 - I agree that 3 is better and would be happy to get a quick PR shifting prepend_bos to option 3

@soheeyang (Contributor, Author) commented Jul 16, 2023

@neelnanda-io Oh no, I just realised that I didn't even explain the difference between option 3 and PR #343 in my comment #344 (comment). Sorry for making you (and possibly others as well) confused 😂 The code I put in the description of option 3 is also applied in PR #343, which is why you said "I didn't realise you'd done 2 (PR #343) rather than 3".

What I really wanted to change from PR #343 was when and where the user-provided default value is given to the model and set. To be more specific, I want to make the following changes:

  1. Get rid of model.set_default_prepend_bos() and model.prepend_bos.
  2. If the user wants to change the default prepend_bos value to False, instead of calling model.set_default_prepend_bos(), they should pass the value when they create the model instance. Then, the model will set the value to model.cfg.default_prepend_bos. If the user does not provide the default value, True is used.
  3. Instead of having a separate model.prepend_bos, the model directly accesses self.cfg.default_prepend_bos to check the default setup (which can be locally overridden by the prepend_bos=... argument given to the string-processing method calls).

The rationale for this change is described in #344 (comment).

@neelnanda-io (Collaborator) commented Jul 16, 2023 via email

@soheeyang (Contributor, Author):

Thank you so much for the feedback! I'll make the quick PR when I wake up (it's time for 🛌 now haha).

@jbloomAus @JayBaileyCS I will do the same thing with padding_side as well and make an update to this PR afterwards.

@soheeyang (Contributor, Author) commented Jul 17, 2023

@neelnanda-io I just thought of an idea that can nicely handle all the overridden flags without touching the code of the actual methods. What do you think of this? I originally came up with it to guarantee that the default value of tokenizer.padding_side is reset, but I think we can handle both prepend_bos and padding_side together here. The advantage of using this for prepend_bos is that people who later change the code of these methods don't need to worry about things like where these overridden flags are handled; it is much simpler with just one decorator, which removes all the duplicated calls to override_or_use_default_flag(prepend_bos, self.cfg.default_prepend_bos)!

# in utils.py
import inspect

def locally_override_and_restore_defaults(function):
    """
    This decorator is used to override the parameters with default values during a function's execution.
    It guarantees that the default parameters are not changed after the function execution,
    even when an error occurs during the execution.
    """
    def wrapper(self, *args, **kwargs):
        sig = inspect.signature(function)
        arg_names = list(sig.parameters.keys())
        if arg_names[0] == "self":
            arg_names = arg_names[1:]

        arg_values = dict(zip(arg_names, args))
        arg_values.update(kwargs)

        # prepend_bos
        # Prepare the overridden value
        default_prepend_bos = self.cfg.default_prepend_bos
        prepend_bos = arg_values.pop("prepend_bos", None)
        assert prepend_bos in [None, True, False], (
            f"prepend_bos must be one of None, True, or False, but got {prepend_bos}."
        )
        arg_values["prepend_bos"] = override_or_use_default_flag(default_prepend_bos, override=prepend_bos)

        # padding_side
        # Prepare the overridden value
        default_padding_side = self.tokenizer.padding_side
        padding_side = arg_values.pop("padding_side", None)
        assert padding_side in [None, "left", "right"], (
            f"padding_side must be one of None, 'left', or 'right', but got {padding_side}."
        )
        
        arg_values["padding_side"] = override_or_use_default_flag(default_padding_side, override=padding_side)

        try:
            # This is important because self.tokenizer.padding_side is
            # the actual padding_side used by Transformers Tokenizer.
            self.tokenizer.padding_side = arg_values["padding_side"]
            
            # Execute the original function with the overridden padding_side

            outputs = function(self, **arg_values)
            
            # Reset the padding_side of the tokenizer
            self.tokenizer.padding_side = default_padding_side
            return outputs

        except Exception as e:
            # If an error occurs, reset the padding_side of the tokenizer before propagating the exception
            self.tokenizer.padding_side = default_padding_side
            raise e

    return wrapper

# in HookedTransformer class

    @utils.locally_override_and_restore_defaults  # This is all we need!
    def to_str_tokens(
        self,
        input: Union[
            str,
            Int[torch.Tensor, "pos"],
            Int[torch.Tensor, "1 pos"],
            Int[np.ndarray, "pos"],
            Int[np.ndarray, "1 pos"],
            list,
        ],
        prepend_bos: Optional[bool] = None,
        padding_side: Optional[Literal["left", "right"]] = None,
    ) -> Union[List[str], List[List[str]]]:

@JayBaileyCS (Collaborator) left a comment:

Overall this looks quite good! I'm humbled by the skill that went into this, having tried and failed to create this feature myself. There's quite a few things in here that I'm going to try to learn from, like the use of .signature and creating a custom decorator to solve this one.

I've made a couple small comments, but overall I am very impressed!

transformer_lens/utils.py (two outdated review threads, resolved)
is_pad_token = 1 - attention_mask.int()

# Find the position of the pad token that is used as the BOS token and should therefore be attended to
pad_bos_positions = is_pad_token.cumsum(dim=-1).argmax(dim=-1)
Collaborator comment:

I believe this will break when the BOS, EOS, and PAD tokens are the same, as is the case in e.g. GPT-2. It would be good to add a test for this case. EOS is often used to separate documents in the same context, e.g. "The cat sat on the mat<|endoftext|>The dog sat on the log". I would add a separate check that the pad tokens are contiguous (e.g. I believe (is_pad_token.cumsum(-1) * ((1 - is_pad_token).cumsum(-1) > 0)).argmax(dim=-1) should work?). I think you should also check for the case where there's no PAD/BOS/EOS in the context, and add a test for correct handling. And this can be different by row, ugh.

This is very fiddly, sorry for all the nitpicks!

@soheeyang (Contributor, Author) commented Jul 25, 2023:

Right, the assumption here was that all the PAD tokens would be contiguous at the beginning and no PAD token would appear in the middle. I didn't consider the case of the PAD token being an EOS token and appearing in the middle of the text. The suggestion of checking whether they are contiguous sounds great. Thank you so much for pointing out this edge case!

@soheeyang (Contributor, Author):

Resolved by commit 0591924

@soheeyang (Contributor, Author) commented Aug 14, 2023:

Note: 0591924 always considers only the leftmost leading pads (when padding_side == left) or the rightmost trailing pads (when padding_side == right) as real pad tokens that should not be attended to, regardless of whether pad == bos or pad == sep.
Rationale: separator tokens should be attended to, and there may be cases where users use separator tokens multiple times.
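
For reference, here is a minimal sketch of the leading-pad rule described above (not the PR's exact code; the helper name and pad_token_id are illustrative): only the contiguous run of pads at the start of each row is treated as padding, so a token with the pad/EOS id that appears mid-sequence, e.g. as a separator, is still attended to. For right padding, the same idea applied to the reversed sequence picks out the trailing pads.

import torch

def leading_pad_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # True only for the contiguous run of pad tokens at the start of each row;
    # cumprod drops to 0 as soon as a non-pad token is seen.
    is_pad = (tokens == pad_token_id).int()
    return is_pad.cumprod(dim=-1).bool()

tokens = torch.tensor([
    [0, 0, 5, 0, 7],    # pad_token_id assumed to be 0; the middle 0 acts as a separator
    [3, 4, 5, 6, 7],    # no padding
])
print(leading_pad_mask(tokens, pad_token_id=0))
# tensor([[ True,  True, False, False, False],
#         [False, False, False, False, False]])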

@neelnanda-io (Collaborator):

Thanks so much for the PR, I've wanted this feature for a while and it's a pain in the ass to implement

@soheeyang (Contributor, Author) commented Jul 25, 2023

> Overall this looks quite good! I'm humbled by the skill that went into this, having tried and failed to create this feature myself. There's quite a few things in here that I'm going to try to learn from, like the use of .signature and creating a custom decorator to solve this one.
>
> I've made a couple small comments, but overall I am very impressed!

Hi @JayBaileyCS, thank you so much for the compliment! Actually, the use of .signature and a custom decorator makes the code complicated and harder to read, which increases the risk of bugs. I only used that approach because I couldn't think of a simpler one, but @neelnanda-io just suggested the brilliant idea of using a context manager instead, which makes the code simpler and less bug-prone: it removes the need for complicated function-signature parsing while doing exactly the same thing. I'll make this update and commit it soon!
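
For reference, a minimal sketch of the context-manager idea (assuming it wraps the tokenizer's padding_side; the actual utils.LocallyOverridenDefaults added in later commits may differ): the default is temporarily overridden and then restored, even if the wrapped call raises.

from contextlib import contextmanager

@contextmanager
def locally_overridden_padding_side(model, padding_side=None):
    # Remember the current default so it can always be restored
    original = model.tokenizer.padding_side
    try:
        if padding_side is not None:
            model.tokenizer.padding_side = padding_side
        yield model
    finally:
        model.tokenizer.padding_side = original

# Usage sketch:
# with locally_overridden_padding_side(model, "left"):
#     logits, cache = model.run_with_cache(prompts)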

@jbloomAus added the seen_by_maintainers label ("Confirms that a maintainer is aware of this card.") on Jul 26, 2023
@neelnanda-io (Collaborator):

Hey! I wanted to check up on the progress of this @soheeyang ?

@soheeyang (Contributor, Author) commented Aug 13, 2023

Hi @neelnanda-io, thank you so much for the reminder! I've made all the necessary updates except for some minor things, but my schedule has been packed recently and I had completely forgotten about this 😂 I will wrap it up today or tomorrow!

# past_kv_cache is not None, so we're doing caching.
# We need to extend the past_left_attention_mask.
# Append '1' to the right of the past_left_attention_mask to account for the new tokens
left_attention_mask = utils.extend_tensor_with_ones(
@soheeyang (Contributor, Author) commented Aug 14, 2023:

I couldn't think of a way that avoids introducing past_left_attention_mask, which is necessary for the index > 0 steps of generate(use_past_kv_cache=True). Does anyone have a better idea?
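
For context, here is a sketch of what an "extend with ones" helper could do at this point (the real utils.extend_tensor_with_ones may differ): append a column of 1s to the running attention mask so that newly generated tokens are always attended to.

import torch

def extend_attention_mask_with_ones(mask: torch.Tensor, n_new: int = 1) -> torch.Tensor:
    # mask: [batch, pos]; newly generated tokens are never padding, so they get 1s
    ones = torch.ones(mask.shape[0], n_new, dtype=mask.dtype, device=mask.device)
    return torch.cat([mask, ones], dim=-1)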

raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
)
with utils.LocallyOverridenDefaults(
@soheeyang (Contributor, Author) commented Aug 14, 2023:

@neelnanda-io I changed the decorator to a context manager; a downside of using a context manager is that it is more cumbersome to add than a decorator. (I intentionally avoided overriding the defaults only partially inside the methods, to prevent any possible inconsistency.)

Context manager that allows temporary overriding of default values within a model.
Once the context is exited, the default values are restored.

WARNING: This context manager must be used for any function/method that directly accesses
@soheeyang (Contributor, Author) commented Aug 14, 2023:

I'm a bit worried about cases where people forget to add this context manager when they need to, but it's hard to enforce... One possible way is something like the following code, but since padding_side is a property of the tokenizer, it is not very straightforward to handle it the same way.

from dataclasses import dataclass, field

@dataclass
class HookedTransformerConfig:
    ...
    _default_prepend_bos: bool = field(default=True, repr=False)
    _inside_context: bool = field(default=False, init=False, repr=False)

    @property
    def default_prepend_bos(self):
        if not self._inside_context:
            raise ValueError("Direct access to default_prepend_bos without context manager is prohibited.")
        return self._default_prepend_bos

    def set_inside_context(self, state: bool):
        self._inside_context = state

@neelnanda-io (Collaborator) left a comment:

This looks good to me on a skim, but I don't have the time to read in detail. @JayBaileyCS do you have the time to look at this? If not, I'll just approve.

@JayBaileyCS (Collaborator):

@neelnanda-io The change has now gotten too big and involved for me to be confident in my own judgment of it, in all honesty. I don't have a good enough understanding of the HookedTransformer in order to track all the variables that have shifted. If it looks good to you on a skim, that's probably better than what I could do for this one, so if you want to approve it I'd say go ahead.

@neelnanda-io (Collaborator):

Fair enough! Thanks Jay. @soheeyang is it ready to be merged in?

@soheeyang (Contributor, Author):

> Fair enough! Thanks Jay. @soheeyang is it ready to be merged in?

@neelnanda-io Yes, it's ready!

Now that we can use left padding, I want to make an enhancement to generate so that it can produce:

  • multiple strings -> strings
  • batched tensors -> strings.

However, I will do that in a separate PR rather than making the change in this one.

@neelnanda-io merged commit 68bdb6d into TransformerLensOrg:main on Sep 1, 2023 (4 checks passed).
@UFO-101 mentioned this pull request on Sep 18, 2023.
@@ -84,15 +87,16 @@ def __init__(
         self.cfg = cfg

         if tokenizer is not None:
-            self.set_tokenizer(tokenizer)
+            self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
@mivanit commented Oct 4, 2023:

(not a review, but a question)

What is the reason for overwriting the tokenizer padding direction like this? This caused some existing trained (left-padded) models to break in a very confusing way, since the tokenizer property we were setting was being overridden without us noticing.

(in reference to #344 )

@soheeyang (Contributor, Author):

The reason is that left padding was technically not supported prior to this commit; only right padding was supported by default, so left-padded models would not have worked correctly before this PR.

Could you give me more details on which model you used and how it was breaking?

@mivanit:

We didn't seem to have problems training and running left-padded models, but perhaps there were underlying issues which we did not notice.

I managed to fix the issue by modifying the tokenizer in our model's __init__ after it gets modified by the new HookedTransformer.__init__
understanding-search/maze-transformer#195

mivanit added a commit to understanding-search/maze-transformer that referenced this pull request on Oct 4, 2023:

as of transformer_lens 1.6.1, in PR TransformerLensOrg/TransformerLens#344 the padding side works differently, and our left-padding was being overridden. this fixes it by doing some things in `ZanjHookedTransformer.__init__` and in the `HuggingMazeTokenizer`
Labels: seen_by_maintainers (Confirms that a maintainer is aware of this card.)

Successfully merging this pull request may close this issue: [Proposal] Add support for padding/masking to the attention computation (Minor feature request)