[BUG] Tensor Shape Mismatch During DeepMD Training #3949
This error happens when model.ntypes=5, len(model.type_map)=4, and data.ntypes=4. @iProzd What are the expected values here?
@PhelanShao Without access to your training data, it's hard to be certain, but the issue might be related to the strong assumption in the TensorFlow spin model that element types with spin come before those without in the type_map. @hztttt Am I correct? If so, we should clarify this requirement in the TF documentation.

@PhelanShao You can also now use the spin model in PyTorch; see here. Note that the data formats for PyTorch and TensorFlow differ, which is also detailed in the doc above.
I believe something will go wrong when a system does not contain any atoms with spin.
Here: deepmd-kit/deepmd/tf/entrypoints/train.py, lines 212 to 216 (commit 29db791). The question here is what the expected value is.
Yeah. In the TensorFlow-backend spin model, element types with spin must appear before those without spin in the type_map.
It is expected no
Thank you! Actually, there is no sulfur (S) element in the system; it is actually an oxygen (O) atom carrying a magnetic moment. To differentiate it when labeling with fp, I worried that naming it O_1 or O_A might break its handling in dpgen, so I replaced it with S. However, the type_map order produced by the dpdata conversion in the dpgen process also needs adjusting, right? Should I rename this element so that it is ordered before C, O, and H?
This is not good behavior, anyway.
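If S is the only spin-carrying type, the ordering requirement above means it has to lead the type_map in the TF backend. A minimal sketch of the reordered model section, assuming the data's type.raw indices are remapped to match (the sel, use_spin, and spin_norm entries are simply permuted into the new order):

```json
{
  "model": {
    "type_map": ["S", "C", "O", "H"],
    "descriptor": {
      "type": "se_e2_a",
      "sel": [52, 16, 46, 92]
    },
    "spin": {
      "use_spin": [true, false, false, false],
      "virtual_len": [0.4],
      "spin_norm": [1.0, 0.0, 0.0, 0.0]
    }
  }
}
```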
Bug summary
When running DeepMD-kit, an error occurs related to a mismatch in tensor shapes. Specifically, the error message indicates that the shape of the input tensor does not match the expected shape, causing a ValueError during training data processing. This issue might be related to the configuration of the magnetic spin parameters in the input files and the corresponding training data.
I suspect the training data generated from CP2K does not contain magnetic spin data, whereas the example I followed may include data derived from VASP's OSZICAR output, which does contain spin data.
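One way to narrow this down is to check how many atom types each training system actually declares, to compare against the model's type count. A minimal sketch (the helper name and the temporary path are hypothetical; type.raw is the standard DeePMD data file listing one integer type index per atom):

```python
import numpy as np
from pathlib import Path


def types_in_system(system_dir):
    """Return the sorted set of type indices used by a DeePMD system.

    Reads the standard ``type.raw`` file (one integer type index per
    atom). If the model defines 5 types (4 real + 1 spin virtual) but a
    system only ever declares 4, the natoms vectors will disagree in
    length, as in the traceback below.
    """
    types = np.loadtxt(Path(system_dir) / "type.raw", dtype=int, ndmin=1)
    return sorted(set(types.tolist()))
```

Running this over every path in training_data.systems would reveal whether any system lacks the spin-carrying type entirely.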
DeePMD-kit Version
registry.dp.tech/dptech/deepmd-kit:2024Q1-d23cf3e
Backend and its version
registry.dp.tech/dptech/deepmd-kit:2024Q1-d23cf3e
How did you download the software?
Others (write below)
Input Files, Running Commands, Error Log, etc.
Traceback (most recent call last):
File "/opt/mamba/bin/dp", line 8, in <module>
sys.exit(main())
File "/opt/mamba/lib/python3.10/site-packages/deepmd/main.py", line 807, in main
deepmd_main(args)
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/entrypoints/main.py", line 72, in main
train_dp(**dict_args)
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/entrypoints/train.py", line 153, in train
_do_work(jdata, run_opt, is_compress)
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/entrypoints/train.py", line 265, in _do_work
model.build(train_data, stop_batch, origin_type_map=origin_type_map)
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/train/trainer.py", line 284, in build
self.model.data_stat(data)
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/model/ener.py", line 128, in data_stat
self._compute_input_stat(
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/model/ener.py", line 147, in _compute_input_stat
self.descrpt.compute_input_stats(
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/descriptor/se_a.py", line 373, in compute_input_stats
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/descriptor/se_a.py", line 841, in _compute_dstats_sys_smth
dd_all = run_sess(
File "/opt/mamba/lib/python3.10/site-packages/deepmd/tf/utils/sess.py", line 31, in run_sess
return sess.run(*args, **kwargs)
File "/opt/mamba/lib/python3.10/site-packages/tensorflow/python/client/session.py", line 972, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/opt/mamba/lib/python3.10/site-packages/tensorflow/python/client/session.py", line 1189, in _run
raise ValueError(
ValueError: Cannot feed value of shape (6,) for Tensor d_sea_t_natoms:0, which has shape (7,)
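For context, the mismatch in the last line is arithmetic: in the TF backend the natoms feed is laid out as [n_local, n_all, n_per_type_0, ...], i.e. length 2 + ntypes, and each type with use_spin=true contributes one extra "virtual" type for the spin pseudo-atoms. A minimal sketch of that accounting (an assumption about the layout, not the actual DeePMD-kit code):

```python
def expected_natoms_len(type_map, use_spin):
    """Length of the natoms vector the spin model's placeholder expects."""
    n_real = len(type_map)         # 4: C, O, H, S
    n_virtual = sum(use_spin)      # 1: only S carries spin
    return 2 + n_real + n_virtual  # [n_local, n_all] + one count per type


def data_natoms_len(n_data_types):
    """Length of the natoms vector built from data with no virtual types."""
    return 2 + n_data_types


print(expected_natoms_len(["C", "O", "H", "S"], [False, False, False, True]))  # 7
print(data_natoms_len(4))  # 6
```

This reproduces the "(6,) for Tensor ... which has shape (7,)" numbers: the model counts 5 types (4 real + 1 virtual) while the data only provides counts for 4.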
Steps to Reproduce
{
"model": {
"type_map": [
"C",
"O",
"H",
"S"
],
"descriptor": {
"type": "se_e2_a",
"sel": [
16,
46,
92,
52
],
"rcut_smth": 1.0,
"rcut": 5.0,
"neuron": [
25,
50,
100
],
"resnet_dt": false,
"axis_neuron": 16,
"seed": 930070626
},
"fitting_net": {
"neuron": [
240,
240,
240
],
"resnet_dt": true,
"seed": 3301444140
},
"spin": {
"use_spin": [
false,
false,
false,
true
],
"virtual_len": [
0.4
],
"spin_norm": [
0.0,
0.0,
0.0,
1.0
],
"_comment4": " that's all"
}
},
"learning_rate": {
"type": "exp",
"start_lr": 0.001,
"decay_steps": 10000
},
"loss": {
"type": "ener_spin",
"start_pref_e": 0.1,
"limit_pref_e": 2,
"start_pref_fr": 1000,
"limit_pref_fr": 1.0,
"start_pref_fm": 10000,
"limit_pref_fm": 10.0,
"start_pref_v": 0,
"limit_pref_v": 0,
"_comment7": " that's all"
},
"training": {
"stop_batch": 50000,
"disp_file": "lcurve.out",
"disp_freq": 500,
"numb_test": 1,
"save_freq": 500,
"save_ckpt": "model.ckpt",
"disp_training": true,
"time_training": true,
"profiling": false,
"profiling_file": "timeline.json",
"_comment": "that's all",
"training_data": {
"systems": [
"../data.init/init/final1/training_data",
"../data.init/init/final2/training_data",
"../data.init/init/final3/training_data",
"../data.iters/iter.000000/02.fp/data.000",
"../data.iters/iter.000000/02.fp/data.001",
"../data.iters/iter.000000/02.fp/data.002",
"../data.iters/iter.000001/02.fp/data.000",
"../data.iters/iter.000001/02.fp/data.001",
"../data.iters/iter.000001/02.fp/data.002",
"../data.iters/iter.000002/02.fp/data.000",
"../data.iters/iter.000002/02.fp/data.001",
"../data.iters/iter.000002/02.fp/data.002",
"../data.iters/iter.000003/02.fp/data.000",
"../data.iters/iter.000003/02.fp/data.001",
"../data.iters/iter.000003/02.fp/data.002",
"../data.iters/iter.000004/02.fp/data.000",
"../data.iters/iter.000004/02.fp/data.001",
"../data.iters/iter.000004/02.fp/data.002",
"../data.iters/iter.000005/02.fp/data.000",
"../data.iters/iter.000005/02.fp/data.001",
"../data.iters/iter.000005/02.fp/data.002",
"../data.iters/iter.000006/02.fp/data.000",
"../data.iters/iter.000006/02.fp/data.001",
"../data.iters/iter.000006/02.fp/data.002",
"../data.iters/iter.000007/02.fp/data.000",
"../data.iters/iter.000007/02.fp/data.001",
"../data.iters/iter.000007/02.fp/data.002",
"../data.iters/iter.000009/02.fp/data.000",
"../data.iters/iter.000010/02.fp/data.000",
"../data.iters/iter.000011/02.fp/data.000",
"../data.iters/iter.000011/02.fp/data.002",
"../data.iters/iter.000012/02.fp/data.000",
"../data.iters/iter.000012/02.fp/data.002",
"../data.iters/iter.000013/02.fp/data.000",
"../data.iters/iter.000013/02.fp/data.002",
"../data.iters/iter.000014/02.fp/data.000",
"../data.iters/iter.000014/02.fp/data.002",
"../data.iters/iter.000015/02.fp/data.000",
"../data.iters/iter.000015/02.fp/data.002",
"../data.iters/iter.000016/02.fp/data.000",
"../data.iters/iter.000016/02.fp/data.002",
"../data.iters/iter.000017/02.fp/data.000",
"../data.iters/iter.000017/02.fp/data.002",
"../data.iters/iter.000018/02.fp/data.000",
"../data.iters/iter.000018/02.fp/data.002",
"../data.iters/iter.000019/02.fp/data.000",
"../data.iters/iter.000019/02.fp/data.002",
"../data.iters/iter.000020/02.fp/data.000",
"../data.iters/iter.000020/02.fp/data.002",
"../data.iters/iter.000021/02.fp/data.000",
"../data.iters/iter.000021/02.fp/data.002",
"../data.iters/iter.000022/02.fp/data.000"
],
"batch_size": [
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
]
},
"seed": 2826335128
}
}
Further Information, Files, and Links
No response