restart multihead fine-tuning fit #773

Open
bernstei opened this issue Jan 7, 2025 · 3 comments
Labels
enhancement New feature or request

Comments

@bernstei
Collaborator

bernstei commented Jan 7, 2025

I'm trying various ways to continue from an initial multihead fine-tuning fit. I started in #622, but it's still not working. I'm trying to figure out what combination of input models/checkpoints, selected pre-training configs, and command line arguments will work.

First I tried to extract the default head and pass that as --foundation_model, but that gave some low-level error (I can reproduce it if useful), and in fact I'm confused about how that could possibly work, because extracting the default head throws away the fine-tuned pt_head.

Then I tried to restart from a checkpoint (this is iterative fitting, so a larger fitting dataset, but the same elements), but that's not working either. If I change nothing except copying in the checkpoint and adding --restart_latest to the command line, it ignores --pt_train_file, picks new pre-training configs, and gives an error:

    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
        size mismatch for atomic_numbers: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([12]).
        size mismatch for node_embedding.linear.weight: copying a param with shape torch.Size([1664]) from checkpoint, the shape in current model is torch.Size([1536]).
        size mismatch for atomic_energies_fn.atomic_energies: copying a param with shape torch.Size([2, 13]) from checkpoint, the shape in current model is torch.Size([2, 12]).
        size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([196608]).
        size mismatch for interactions.1.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([196608]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([12, 23, 128]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([12, 4, 128]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([12, 1, 128]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([12, 23, 128]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([12, 4, 128]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([12, 1, 128]).
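
For reference, the restart attempt was roughly of the following form (a hypothetical sketch; file names are placeholders, and everything other than the copied-in checkpoint and --restart_latest is unchanged from the original fine-tuning run):

    mace_run_train \
        --name="finetune" \
        --foundation_model="small" \
        --pt_train_file="selected_pt_configs.xyz" \
        --train_file="iter_02_train.xyz" \
        --restart_latest        # plus the remaining fitting arguments, unchanged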

This code in mace/mace/cli/run_train.py (lines 276 to 279 in 49293b8)

    if (
        args.foundation_model in ["small", "medium", "large"]
        or args.pt_train_file == "mp"
    ):

suggests that if I pass --foundation_model=small it'll ignore --pt_train_file, so I removed the foundation model from the command line, but then it doesn't do multihead at all (according to the log), and it fails loading the checkpoint:

  File "/home/Software/python/conda/torch/2.5.1/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
        size mismatch for atomic_numbers: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for node_embedding.linear.weight: copying a param with shape torch.Size([1664]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for radial_embedding.bessel_fn.bessel_weights: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([8]).
        size mismatch for atomic_energies_fn.atomic_energies: copying a param with shape torch.Size([2, 13]) from checkpoint, the shape in current model is torch.Size([1, 1]).
        size mismatch for interactions.0.conv_tp_weights.layer0.weight: copying a param with shape torch.Size([10, 64]) from checkpoint, the shape in current model is torch.Size([8, 64]).
        size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([65536]).
        size mismatch for interactions.0.skip_tp.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for interactions.1.linear_up.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([32768]).
        size mismatch for interactions.1.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for interactions.1.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([5120]).
        size mismatch for interactions.1.conv_tp_weights.layer0.weight: copying a param with shape torch.Size([10, 64]) from checkpoint, the shape in current model is torch.Size([8, 64]).
        size mismatch for interactions.1.conv_tp_weights.layer3.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1280]).
        size mismatch for interactions.1.linear.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([163840]).
        size mismatch for interactions.1.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([16384]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([1, 23, 128]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([1, 4, 128]).
        size mismatch for products.0.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
        size mismatch for products.0.linear.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([32768]).
        size mismatch for products.0.linear.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([1, 23, 128]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([1, 4, 128]).
        size mismatch for products.1.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
        size mismatch for readouts.0.linear.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for readouts.0.linear.output_mask: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for readouts.1.linear_1.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for readouts.1.linear_1.output_mask: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for readouts.1.linear_2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for readouts.1.linear_2.output_mask: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for scale_shift.scale: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for scale_shift.shift: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
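
Reading the shapes in this second traceback: the freshly built model apparently covers a single element (atomic_numbers of size 1 instead of 13), has a single readout head (readouts and scale_shift of size 1 instead of 2), and uses a default radial basis (8 Bessel weights instead of the checkpoint's 10). In other words, without --foundation_model the run seems to build a plain single-head model from the fine-tuning data alone, which the multihead checkpoint cannot be loaded into.
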
@ilyes319
Contributor

ilyes319 commented Jan 8, 2025

Indeed, it should be much easier for this to work. The heart of the problem is the subselection of elements. I think there should be a cleaner way to just keep all the elements from the foundation model, and then that would solve this too.
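
That would be consistent with the numbers in the first traceback above: the mismatches are exactly one element wide, e.g. atomic_numbers has 13 entries in the checkpoint but 12 in the rebuilt model, and node_embedding.linear.weight goes from 13 × 128 = 1664 to 12 × 128 = 1536, presumably because the re-selected pre-training configs cover one element fewer.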

@bernstei
Collaborator Author

bernstei commented Jan 8, 2025

I think there should be a cleaner way to just keep all the elements from the foundation model

I would like this as well, since it would also make it easier to add new elements in the middle of an iterative fit.

@ilyes319
Contributor

ilyes319 commented Jan 13, 2025

@bernstei I have added that functionality in the latest dev branch, if you can test it. You just need to add --foundation_model_elements=True.
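
A minimal sketch of how that would be added to the fine-tuning command (hypothetical file names; remaining arguments as before):

    mace_run_train \
        --name="finetune" \
        --foundation_model="small" \
        --foundation_model_elements=True \
        --pt_train_file="selected_pt_configs.xyz" \
        --train_file="iter_03_train.xyz" \
        --restart_latest        # plus the remaining fitting arguments, unchanged

With all of the foundation model's elements kept, the element set (and hence the parameter shapes) should no longer depend on which pre-training configs get selected, which should also make restarts and adding new elements mid-way through an iterative fit easier.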

ilyes319 added the enhancement label Jan 22, 2025