
Future reserved tokens and backward (in)compatibility #102

Open

willdumm opened this issue Jan 9, 2025 · 2 comments

willdumm commented Jan 9, 2025

After the paired chain PR #92 we can no longer load old models, because the amino acid embedding has changed size:

RuntimeError: Error(s) in loading state_dict for TransformerBinarySelectionModelWiggleAct:
	size mismatch for amino_acid_embedding.weight: copying a param with shape torch.Size([21, 128]) from checkpoint, the shape in current model is torch.Size([22, 128]).

We will want to add additional reserved tokens to our model inputs in the future. It would be great to not have to throw away all old trained models whenever we do. One way to handle this would be:

  • add all the reserved tokens we can imagine using, even if we don't make use of them yet
  • add model and encoder versions, which will be serialized and checked whenever loading models to ensure that we don't try to run old models with new code that isn't designed to handle them.
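
Something like this is what I have in mind for the version check (a sketch only; the checkpoint layout and field names here are hypothetical):

    # Sketch: store an encoder/model version alongside the weights and refuse to
    # load checkpoints written by a newer encoder than this code understands.
    import torch

    ENCODER_VERSION = 2  # bumped whenever the token set / encoding changes

    def save_checkpoint(model, hyperparameters, path):
        torch.save(
            {
                "encoder_version": ENCODER_VERSION,
                "hyperparameters": hyperparameters,
                "state_dict": model.state_dict(),
            },
            path,
        )

    def load_checkpoint(model, path):
        checkpoint = torch.load(path)
        # Checkpoints written before versioning existed get version 1.
        saved_version = checkpoint.get("encoder_version", 1)
        if saved_version > ENCODER_VERSION:
            raise ValueError(
                f"Checkpoint uses encoder version {saved_version}, but this code "
                f"only understands versions up to {ENCODER_VERSION}."
            )
        model.load_state_dict(checkpoint["state_dict"])
        return model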

It would be nice to figure out a way to keep our model architecture flexible enough to not have to reserve tokens proactively like this. I'll be doing some reading to see if that's possible.

willdumm commented Jan 9, 2025

New ideas:

  • Where we've hardcoded the embedding dimension here
    self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
    we will now load the embedding dimension from a model hyperparameter, whose default will be 20 for loaded models that don't record it, and the current value for new models. There are other places where this is hardcoded, and they'll need to be updated to use the hyperparameter as well (see the checkpoint-loading sketch after this list).
  • To avoid backward incompatibilities when adding new tokens, we should move the ambiguous character directly after the concrete AA characters, and add any new tokens to the very end of this string:
    AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
    This will make the most recently trained models (after the paired PR) unusable, but will allow us to use any models trained before that (see the token-ordering sketch after this list).
  • dnsm-experiments-1 code will continue to be updated to scaffold input sequences with whatever reserved tokens we like. If an older model is presented with a sequence containing unsupported tokens, we need to figure out what to do with those so the model doesn't see them.
    • We discussed just stripping them out, but I think that will be difficult to handle, because we use the raw sequences in lots of places, for evaluating the neutral model and computing branch lengths. If I were to do it this way, I think I'd suggest stripping out the unknown tokens in the model application, and then inserting NaNs in the corresponding sites in the result so that the result has the same length as the unmodified input sequence (see the strip-and-reinsert sketch after this list).
    • Can we instead mask the sites containing unknown tokens, but still unmask those sites that contain known reserved tokens when evaluating the model? In some cases this could result in sequences being shifted compared to the input data the model was trained on. This could be more of an issue if we start putting reserved tokens in the middle of the sequences. This approach would have the advantage that the size of the output would always match the input sequence length...
    • We have to keep in mind here that although the Dataset class does all the data preparation, it does not have any access to the model class that will be evaluated on the data.
  • As long as we test this approach to make sure it'll be backward-compatible when we add more tokens, we shouldn't have to preemptively add them all now.
  • Model versions don't seem absolutely necessary now, and the hyperparameter describing the dimension of the embedding should only grow, making it a sort of model version in itself. However, a more explicit model version may be something we'll want in the future?
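
To make the first bullet concrete, here is roughly the checkpoint-loading sketch I have in mind (hypothetical class and field names, not the actual netam API; the numbers match the error above, where old models embed 21 tokens and current ones 22):

    # Sketch: the number of embedded tokens becomes a serialized hyperparameter
    # rather than a hardcoded constant, so an old checkpoint rebuilds an
    # embedding of the size it was trained with.
    import torch
    import torch.nn as nn

    MAX_AA_TOKEN_IDX = 21  # current value after the paired-chain PR (22 embedding rows)

    class SelectionModelSketch(nn.Module):
        def __init__(self, d_model: int, max_aa_token_idx: int = MAX_AA_TOKEN_IDX):
            super().__init__()
            self.hyperparameters = {"d_model": d_model, "max_aa_token_idx": max_aa_token_idx}
            self.amino_acid_embedding = nn.Embedding(max_aa_token_idx + 1, d_model)

    def load_sketch(path):
        checkpoint = torch.load(path)
        hyperparameters = checkpoint["hyperparameters"]
        # Checkpoints written before this change don't record the token count;
        # fall back to the pre-paired-PR value of 20.
        max_aa_token_idx = hyperparameters.get("max_aa_token_idx", 20)
        model = SelectionModelSketch(hyperparameters["d_model"], max_aa_token_idx)
        model.load_state_dict(checkpoint["state_dict"])
        return model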
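
For the token-ordering bullet, this is my reading of the proposal (the reserved-token characters below are placeholders, not the real constants):

    # Sketch: the ambiguous character sits immediately after the 20 concrete
    # amino acids, so its index (20) never changes, and reserved tokens only
    # ever append to the end of the string.
    AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"  # concrete amino acids, indices 0-19
    AMBIGUOUS_CHAR = "X"                    # fixed at index 20
    RESERVED_TOKENS = "^"                   # placeholder; grows as we add tokens

    AA_TOKEN_STR_SORTED = AA_STR_SORTED + AMBIGUOUS_CHAR + RESERVED_TOKENS
    # Appending a new reserved token grows the embedding without moving any
    # existing token's index, so older checkpoints keep working.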
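
And the strip-and-reinsert idea might look roughly like this (hypothetical names; model_fn stands in for however the model is applied to a single sequence, and is assumed to return one row of float output per site):

    # Sketch: remove tokens the model doesn't know, run the model on the
    # stripped sequence, then write NaNs back at the removed positions so the
    # output aligns with the unmodified input.
    import torch

    def apply_model_with_unknown_tokens(model_fn, sequence, known_tokens):
        known_positions = [i for i, aa in enumerate(sequence) if aa in known_tokens]
        stripped = "".join(sequence[i] for i in known_positions)
        stripped_output = model_fn(stripped)  # shape: (len(stripped), ...)
        full_output = torch.full(
            (len(sequence), *stripped_output.shape[1:]), float("nan")
        )
        full_output[known_positions] = stripped_output
        return full_output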

As a unit test, we'll use an old trained model and its outputs on a fixed sequence, and compare those to the outputs we get when we load and evaluate the model with new code.
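
Roughly (fixture paths are placeholders, and the loader is passed in so the sketch doesn't presume any particular loading API):

    # Sketch of the regression test: an old checkpoint plus its recorded outputs
    # on a fixed sequence are kept as fixtures, and new code must reproduce them.
    import torch

    def check_old_model_outputs_unchanged(load_model, checkpoint_path,
                                          encoded_sequence_path, expected_output_path):
        model = load_model(checkpoint_path)          # trained before the paired PR
        encoded = torch.load(encoded_sequence_path)  # the fixed sequence, already encoded
        expected = torch.load(expected_output_path)  # outputs recorded with the old code
        with torch.no_grad():
            actual = model(encoded)
        assert torch.allclose(actual, expected, atol=1e-6)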

matsen commented Jan 10, 2025

Thanks for this thoughtful and detailed writeup!

> We discussed just stripping them out, but I think that will be difficult to handle, because we use the raw sequences in lots of places, for evaluating the neutral model and computing branch lengths. If I were to do it this way, I think I'd suggest stripping out the unknown tokens in the model application, and then inserting NaNs in the corresponding sites in the result so that the result has the same length as the unmodified input sequence.

First I want to remind you that N sites for the SHM model are not "trivial": the SHM model for the G in AAGCT is different from that for AANNNGCT. This is OK when we are looking at the junction between two chains, because we handle the heavy and the light SHM separately, but it wouldn't be true if we had some other special token.

Important: I don't want to get hung up on the "future special token" thing. I think that we should as much as possible just extend the current setup such that it works for the H/L case without setting any huge traps for the future. I think that this could just be:

  1. Moving ambiguous tokens to the end
  2. Masking exactly as we are now for the H/L tokens

I'd love for this to be wrapped by EOD Monday at the latest so we can return to the important work of actually training models.

> We have to keep in mind here that although the Dataset class does all the data preparation, it does not have any access to the model class that will be evaluated on the data.

A cursory look at when we instantiate Dataset classes indicated that in our main applications, we are making a model and a Dataset class at the same time. Although we should always hesitate when introducing additional dependencies, if it would be useful it seems like one could pass the model or model class to the Dataset constructor.
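
A minimal sketch of that direction (entirely hypothetical names, just to show the shape of the dependency):

    # Sketch: the Dataset takes the model at construction time so that data
    # preparation can know which tokens the model supports.
    from torch.utils.data import Dataset

    class SequenceDatasetSketch(Dataset):
        def __init__(self, sequences, model):
            self.sequences = sequences
            # Hypothetical attribute: however the model ends up recording the
            # token alphabet it was trained with.
            self.known_tokens = set(model.known_token_str)

        def __len__(self):
            return len(self.sequences)

        def __getitem__(self, idx):
            sequence = self.sequences[idx]
            known_site_mask = [aa in self.known_tokens for aa in sequence]
            return sequence, known_site_mask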

I also note that for the SHM models we store information about the sequence encoder in the crepe. We should probably start calling that notion of encoder a "tokenizer" or something, because it's not the transformer encoder.

Let's hold off on introducing model versions for now.

Again, thanks! ✨
