Fix random_mask_tokenize when the text is long #680

Closed
bryant1410 wants to merge 1 commit

Conversation

bryant1410 (Contributor)

Without this patch, the function crashes for long texts. See https://colab.research.google.com/drive/1SHBAUEnI1dNJmXQPUqZekFqXm7xrwH65?usp=sharing
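
(As the discussion below confirms, the crash comes down to indexing a Python list with a list of indices.) A minimal standalone reproduction with made-up values, outside the tokenizer itself:

import numpy as np

tokens = list(range(100))                                   # stand-in for a long encoded text
indices = np.random.permutation(len(tokens)).tolist()[:75]  # 77 - 2, leaving room for sot/eot
tokens[indices]  # TypeError: list indices must be integers or slices, not list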

@bryant1410 (Contributor, Author)

BTW, wouldn't it be simpler (and probably avoid some larger memory copies) to replace this:

indices = np.random.permutation(len(tokens)).tolist()
indices = indices[:context_length - 2]
tokens = [tokens[i] for i in indices]

with this:

tokens = random.sample(tokens, context_length - 2)

?

And similarly for other random sampling in this file that uses NumPy methods.
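
A minimal sketch of that alternative in isolation (example values made up; inside the tokenizer, tokens and context_length come from the enclosing function):

import random

context_length = 77
tokens = list(range(200))  # stand-in for a long encoded text

if len(tokens) > context_length - 2:  # reserve 2 slots for sot/eot
    # random.sample draws context_length - 2 tokens without replacement, in random
    # order, which matches the permutation-then-slice behaviour above
    tokens = random.sample(tokens, context_length - 2)

Like the permutation-based version, this does not preserve the original token order, which is the shuffle-vs-mask question raised further down.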

@rwightman (Collaborator) commented Oct 18, 2023

@bryant1410 so you're right about the issue, and random.sample would be a good alternative as long as tokens stays a list, as it is right now. Having looked at these tokenizers more closely since the merge, though, I'm not sure they've all been used and tested, so I'm currently wondering whether we should keep them in there in this state.

Clearly this one does not work as implemented: it was written as if tokens were a NumPy array or tensor (where indexing with a list of indices is valid), but tokens is actually a Python list. Also, I'm failing to see how this is random masking as in the paper; it looks like a full random shuffle of the tokens. @zw615?

@bryant1410 (Contributor, Author)

I see. I'll wait for clarification then, before making more changes to this PR.

@rwightman (Collaborator) commented Oct 18, 2023

Wouldn't this be a more appropriate 'random mask', vs. the current impl, which seems to be a 'random shuffle'?

import torch

N = 10  # context length
B = 1   # batch size (for a batched impl)
mask_rate = 0.3
num_keep = max(1, int(N * (1 - mask_rate)))
indices = torch.argsort(torch.randn(B, N), dim=-1)[:, :num_keep]  # num_keep random positions per row
# indices = torch.randperm(N)[:num_keep]  # simpler for an unbatched impl, one row at a time
indices = indices.sort(dim=-1)[0]  # sort back into original order
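
For illustration (a sketch with made-up token ids), applying the kept indices to a single sequence preserves the original relative order, which is what separates masking from shuffling:

import torch

tokens = torch.tensor([11, 22, 33, 44, 55, 66, 77, 88, 99, 100])  # made-up ids, N = 10
mask_rate = 0.3
num_keep = max(1, int(len(tokens) * (1 - mask_rate)))    # keep 7 of 10
keep = torch.randperm(len(tokens))[:num_keep].sort()[0]  # sorted, so original order survives
masked = tokens[keep]  # random subset of tokens, still in their original relative order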

@zw615

@rwightman (Collaborator)

For the version in #660 I was thinking of:

class RandomMaskTokenizer(SimpleTokenizer):
    def __init__(
            self,
            bpe_path: str = default_bpe(),
            special_tokens=None,
            clean: str = 'lower',
            shuffle: bool = False,
    ):
        super().__init__(bpe_path, special_tokens, clean)
        self.shuffle = shuffle

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<start_of_text>"]
        eot_token = self.encoder["<end_of_text>"]
        all_tokens = [self.encode(text) for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            tokens = torch.tensor(tokens)
            num_tokens = len(tokens)
            if num_tokens > context_length - 2:  # 2 for sot and eot token
                keep_num = context_length - 2
                indices = torch.randperm(len(tokens))
                indices = indices[:keep_num]
                if not self.shuffle:
                    indices = indices.sort()[0]
                tokens = tokens[indices]
                num_tokens = keep_num
            result[i, 0] = sot_token
            result[i, 1:num_tokens + 1] = tokens
            result[i, num_tokens + 1] = eot_token

        return result
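
A quick usage sketch of the proposed class (assuming it lives alongside SimpleTokenizer and default_bpe in open_clip's tokenizer module; the input text is made up):

tokenizer = RandomMaskTokenizer(shuffle=False)  # keep surviving tokens in original order
long_text = "a photo of a cat sitting on a very long windowsill " * 20  # deliberately overlong
out = tokenizer([long_text], context_length=77)
print(out.shape)  # torch.Size([1, 77]); the overlong input is randomly subsampled to fit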

rwightman added a commit that referenced this pull request Oct 18, 2023
…default __call__() arg to None. Clean up reduction masking logic and fix #680
rwightman closed this in a5f3ae9 on Oct 20, 2023
bryant1410 deleted the patch-1 branch on October 20, 2023 at 16:00