Language Translation with nn.Transformer and torchtext [BUG] - mask with -inf leads to nans. #2988

Open
danielegana opened this issue Jul 31, 2024 · 1 comment


danielegana commented Jul 31, 2024


https://pytorch.org/tutorials/beginner/translation_transformer.html

Describe the bug

Running the language translation tutorial with transformers produces NaNs during training, already on the first batch of the first epoch, and even when evaluating an untrained model on some input sequences.

I hit this issue simply by copy-pasting the tutorial to my local machine and starting training. The problem seems to stem from the target mask. Replacing the line

mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

by

mask = mask.float().masked_fill(mask == 0, float('-1e9')).masked_fill(mask == 1, float(0.0))

lets the model train for a few batches of the first epoch without NaN outputs. A possibly related problem has been pointed out in a comment on pytorch/pytorch#41508.
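A minimal pure-Python sketch of one well-known route to NaNs with an additive `-inf` mask may help here. It assumes the attention softmax uses the usual max-subtraction trick (PyTorch's actual kernels may differ in detail). If every key in a query row is masked, every score becomes `-inf`, the row maximum is `-inf`, `-inf - (-inf)` is NaN, and the whole row turns into NaN. A row with at least one visible key, such as any row of the causal mask on its own, stays finite. With `-1e9` the fully masked row stays finite, but the mask no longer blocks anything: the masked keys receive nonzero weight again. The helper names `causal_mask`, `add_mask`, and `stable_softmax` are hypothetical.

```python
import math

NEG_INF = float("-inf")


def causal_mask(size, fill):
    """Additive causal mask: 0.0 where key j <= query i, `fill` for future keys.

    Plain-list analogue of the tutorial's triu/transpose/masked_fill mask,
    with the fill value (-inf or -1e9) passed in.
    """
    return [[0.0 if j <= i else fill for j in range(size)] for i in range(size)]


def add_mask(scores, mask_row):
    """Add one mask row to one row of attention scores, element-wise."""
    return [s + m for s, m in zip(scores, mask_row)]


def stable_softmax(scores):
    """Softmax with max subtraction; an all -inf row yields NaN (-inf - -inf)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, 1.0, 2.0]

# First row of a causal mask: only key 0 is visible, so the result is finite.
print(stable_softmax(add_mask(scores, causal_mask(3, NEG_INF)[0])))  # -> [1.0, 0.0, 0.0]

# A row where every key is masked with -inf: every entry becomes NaN.
print(stable_softmax(add_mask(scores, [NEG_INF] * 3)))  # -> [nan, nan, nan]

# The same row masked with -1e9: finite, but the masked keys get weight again.
print(stable_softmax(add_mask(scores, [-1e9] * 3)))
```

This is why replacing `-inf` with `-1e9` hides the NaNs rather than removing their cause: a fully masked row silently attends to positions it was meant to ignore.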

However, even with this "fix", the loss increases during training and eventually becomes NaN as well.

Describe your environment

Running on macOS with PyTorch 2.2.2 and Python 3.9.7.

cc @pytorch/team-text-core @Nayef211

@lmntrx-sys

Suggested Steps
Review Mask Implementation: Investigate the mask implementation in the tutorial and ensure it's correctly applied during training and evaluation.

Gradient Clipping: Implement gradient clipping to prevent exploding gradients, which could lead to NaNs.

Learning Rate: Experiment with different learning rates to see if a lower learning rate stabilizes the training process.

Check Data Preprocessing: Ensure that the input data is correctly preprocessed and normalized.

Monitor NaN Values: Add checks to monitor NaN values during training and identify when and where they first appear.
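The gradient-clipping and NaN-monitoring suggestions can be sketched in plain Python over a flat list of gradient values. This illustrates the idea only and is not PyTorch code: in a real training loop the usual calls would be `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` between `loss.backward()` and `optimizer.step()`, plus `torch.isfinite` checks or `torch.autograd.set_detect_anomaly(True)` to locate where NaNs first appear. The `max_norm` value and the helper names `clip_by_global_norm`, `first_non_finite`, and `check_step` are hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients so their global L2 norm is at most max_norm.

    Mirrors the idea behind torch.nn.utils.clip_grad_norm_. Returns the
    (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(total_norm):
        # Clipping cannot repair NaN/inf gradients; leave them for the caller to catch.
        return list(grads), total_norm
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        return [g * scale for g in grads], total_norm
    return list(grads), total_norm


def first_non_finite(values):
    """Index of the first NaN or inf in values, or None if all are finite."""
    for i, v in enumerate(values):
        if not math.isfinite(v):
            return i
    return None


def check_step(step, loss, grads):
    """Raise as soon as the loss or any gradient stops being finite."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss {loss} at step {step}")
    bad = first_non_finite(grads)
    if bad is not None:
        raise FloatingPointError(f"non-finite gradient at index {bad}, step {step}")


# Example: a gradient of norm 5.0 is scaled down to a norm of about 1.0.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)
```

Clipping only bounds the size of finite updates; once a NaN has appeared (for example from a fully masked attention row) it cannot help, which is why the check raises instead of trying to continue.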
