I'm trying various ways to continue from an initial multihead fine-tuning fit. I started in #622, but it's still not working, and I'm trying to figure out what combination of input models/checkpoints, selected pre-training configs, and command-line arguments will work.
First I tried to extract the default head and pass that as --foundation_model, but that gave some low-level error (I can reproduce it if useful), and in fact I'm confused about how it could possibly work, because extracting the default head throws out the fine-tuned pt_head.
Then I tried to restart from a checkpoint (this is iterative fitting, so a larger fitting dataset, but the same elements), but that's not working either. If I change nothing except copying in the checkpoint and adding --restart_latest to the command line, it ignores --pt_train_file, picks new pre-training configs, and gives an error:
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
size mismatch for atomic_numbers: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([12]).
size mismatch for node_embedding.linear.weight: copying a param with shape torch.Size([1664]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for atomic_energies_fn.atomic_energies: copying a param with shape torch.Size([2, 13]) from checkpoint, the shape in current model is torch.Size([2, 12]).
size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([196608]).
size mismatch for interactions.1.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([196608]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([12, 23, 128]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([12, 4, 128]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([12, 1, 128]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([12, 23, 128]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([12, 4, 128]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([12, 1, 128]).
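For what it's worth, the shapes already tell part of the story: the checkpoint carries 13 elements (atomic_numbers of shape [13], node embedding 13 × 128 = 1664), while the restarted run built a 12-element model (12 × 128 = 1536), so the two runs are not seeing the same element list. A minimal sketch for inspecting what the checkpoint expects before restarting (the path is a placeholder, and I'm assuming the file is either a plain state_dict, a dict wrapping one, or a pickled model; adjust if the layout differs):

```python
import torch

# placeholder path -- whatever checkpoint --restart_latest is picking up
ckpt_path = "checkpoints/run_epoch-50.pt"

obj = torch.load(ckpt_path, map_location="cpu")
# the file may be a plain state_dict, a dict wrapping one (e.g. under a
# "model" key), or a pickled model object; handle the cases I know about
if isinstance(obj, dict):
    state_dict = obj.get("model", obj)
else:
    state_dict = obj.state_dict()

# element list the checkpoint was trained with
z = state_dict["atomic_numbers"]
print("checkpoint atomic numbers:", [int(v) for v in z], f"({z.numel()} elements)")

# a few tensors whose shapes scale with the number of elements / heads
for key in ("node_embedding.linear.weight",
            "atomic_energies_fn.atomic_energies",
            "scale_shift.scale"):
    print(key, tuple(state_dict[key].shape))
```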
This block in mace/mace/cli/run_train.py (lines 276 to 279 at 49293b8) suggests that if I pass --foundation_model=small it'll ignore --pt_train_file, so I remove the foundation model, but then it doesn't do multihead at all (according to the log), and gives an error loading the checkpoint:
File "/home/Software/python/conda/torch/2.5.1/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
size mismatch for atomic_numbers: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for node_embedding.linear.weight: copying a param with shape torch.Size([1664]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for radial_embedding.bessel_fn.bessel_weights: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([8]).
size mismatch for atomic_energies_fn.atomic_energies: copying a param with shape torch.Size([2, 13]) from checkpoint, the shape in current model is torch.Size([1, 1]).
size mismatch for interactions.0.conv_tp_weights.layer0.weight: copying a param with shape torch.Size([10, 64]) from checkpoint, the shape in current model is torch.Size([8, 64]).
size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([65536]).
size mismatch for interactions.0.skip_tp.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for interactions.1.linear_up.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([32768]).
size mismatch for interactions.1.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for interactions.1.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([5120]).
size mismatch for interactions.1.conv_tp_weights.layer0.weight: copying a param with shape torch.Size([10, 64]) from checkpoint, the shape in current model is torch.Size([8, 64]).
size mismatch for interactions.1.conv_tp_weights.layer3.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1280]).
size mismatch for interactions.1.linear.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([163840]).
size mismatch for interactions.1.skip_tp.weight: copying a param with shape torch.Size([212992]) from checkpoint, the shape in current model is torch.Size([16384]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([1, 23, 128]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([1, 4, 128]).
size mismatch for products.0.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
size mismatch for products.0.linear.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([32768]).
size mismatch for products.0.linear.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights_max: copying a param with shape torch.Size([13, 23, 128]) from checkpoint, the shape in current model is torch.Size([1, 23, 128]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights.0: copying a param with shape torch.Size([13, 4, 128]) from checkpoint, the shape in current model is torch.Size([1, 4, 128]).
size mismatch for products.1.symmetric_contractions.contractions.0.weights.1: copying a param with shape torch.Size([13, 1, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
size mismatch for readouts.0.linear.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for readouts.0.linear.output_mask: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for readouts.1.linear_1.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for readouts.1.linear_1.output_mask: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
size mismatch for readouts.1.linear_2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([16]).
size mismatch for readouts.1.linear_2.output_mask: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for scale_shift.scale: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for scale_shift.shift: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
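This second traceback also shows the "no multihead" part directly in the shapes: the checkpoint has two heads (scale_shift.scale of size 2, and readouts.0.linear.weight of 256, which looks like 2 heads × 128 channels) and 13 elements, while the freshly built model has one head, one element, and even different radial settings (8 vs 10 Bessel weights). As a pre-flight check before the next attempt, I can at least compare the checkpoint's element list against the union of elements in the files I'm passing as --train_file and --pt_train_file. A sketch, with placeholder paths and the same assumption as above about how the checkpoint stores its weights:

```python
import torch
from ase.io import read

ckpt_path = "checkpoints/run_epoch-50.pt"            # placeholder
data_files = ["train_iter2.xyz", "pt_selected.xyz"]  # placeholders for --train_file / --pt_train_file

obj = torch.load(ckpt_path, map_location="cpu")
sd = obj.get("model", obj) if isinstance(obj, dict) else obj.state_dict()
ckpt_z = {int(z) for z in sd["atomic_numbers"]}

# union of atomic numbers across all configurations in the data files
data_z = set()
for path in data_files:
    for atoms in read(path, index=":"):
        data_z.update(int(z) for z in atoms.get_atomic_numbers())

print("in checkpoint but missing from the data:", sorted(ckpt_z - data_z))
print("in the data but missing from the checkpoint:", sorted(data_z - ckpt_z))
```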
Indeed, it should be much easier for this to work. The heart of the problem is the subselection of elements. I think there should be a cleaner way to just keep all the elements from the foundation model, and then it would solve this too.
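At least numerically, the element count alone accounts for every mismatch in the first traceback: with 128 channels, going from 13 to 12 elements changes exactly the element-indexed tensors (node embedding, skip_tp, symmetric-contraction weights, atomic energies) and nothing else. A rough check, assuming those weights are laid out per element:

```python
# with 128 channels, dropping one element (13 -> 12) reproduces the numbers
# reported in the first traceback
channels = 128
for n_elements in (13, 12):
    print(f"{n_elements} elements:",
          f"node_embedding.linear.weight = {n_elements * channels},",        # 1664 / 1536
          f"interactions.*.skip_tp.weight = {n_elements * channels ** 2},",  # 212992 / 196608
          f"weights_max shape = ({n_elements}, 23, {channels})")
```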