Skip to content

Commit

Permalink
fix(pt): set weights_only=True for torch.load (#4147)
Browse files Browse the repository at this point in the history
Fix #4143.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Model checkpoints are now loaded with `torch.load(..., weights_only=True)`,
    which restricts unpickling to tensors, primitive types, and plain
    containers instead of arbitrary pickled Python objects.

- **Bug Fixes**
  - Applied `weights_only=True` consistently to the `torch.load` calls in the
    PyTorch backend's training entry points, inference, fine-tuning, and
    serialization code, so every checkpoint is loaded the same restricted
    way.

- **Tests**
  - Updated the descriptor and save/load tests (`test_descriptor_dpa1.py`,
    `test_descriptor_dpa2.py`, `test_saveload_dpa1.py`,
    `test_saveload_se_e2_a.py`) to call `torch.load` with
    `weights_only=True`, matching the new loading behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Oct 23, 2024
1 parent dccb0e5 commit 911f41b
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 15 deletions.
8 changes: 6 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def train(
# update init_model or init_frz_model config if necessary
if (init_model is not None or init_frz_model is not None) and use_pretrain_script:
if init_model is not None:
init_state_dict = torch.load(init_model, map_location=DEVICE)
init_state_dict = torch.load(
init_model, map_location=DEVICE, weights_only=True
)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
Expand Down Expand Up @@ -380,7 +382,9 @@ def change_bias(
output: Optional[str] = None,
):
if input_file.endswith(".pt"):
old_state_dict = torch.load(input_file, map_location=env.DEVICE)
old_state_dict = torch.load(
input_file, map_location=env.DEVICE, weights_only=True
)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif input_file.endswith(".pth"):
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(
self.output_def = output_def
self.model_path = model_file
if str(self.model_path).endswith(".pt"):
state_dict = torch.load(model_file, map_location=env.DEVICE)
state_dict = torch.load(
model_file, map_location=env.DEVICE, weights_only=True
)
if "model" in state_dict:
state_dict = state_dict["model"]
self.input_param = state_dict["_extra_state"]["model_params"]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
- config: The Dict-like configuration with training options.
"""
# Model
state_dict = torch.load(model_ckpt, map_location=DEVICE)
state_dict = torch.load(model_ckpt, map_location=DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_params = state_dict["_extra_state"]["model_params"]
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ def get_lr(lr_params):
optimizer_state_dict = None
if resuming:
log.info(f"Resuming from {resume_model}.")
state_dict = torch.load(resume_model, map_location=DEVICE)
state_dict = torch.load(
resume_model, map_location=DEVICE, weights_only=True
)
if "model" in state_dict:
optimizer_state_dict = (
state_dict["optimizer"] if finetune_model is None else None
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_finetune_rules(
Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs.
"""
multi_task = "model_dict" in model_config
state_dict = torch.load(finetune_model, map_location=env.DEVICE)
state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
last_model_params = state_dict["_extra_state"]["model_params"]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def serialize_from_file(model_file: str) -> dict:
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())
elif model_file.endswith(".pt"):
state_dict = torch.load(model_file, map_location="cpu")
state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_def_script = state_dict["_extra_state"]["model_params"]
Expand Down
10 changes: 6 additions & 4 deletions source/tests/pt/model/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,15 @@ def test_descriptor_block(self):
des = DescrptBlockSeAtten(
**dparams,
).to(env.DEVICE)
des.load_state_dict(torch.load(self.file_model_param))
des.load_state_dict(torch.load(self.file_model_param, weights_only=True))
coord = self.coord
atype = self.atype
box = self.cell
# handle type_embedding
type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE)
type_embedding.load_state_dict(torch.load(self.file_type_embed))
type_embedding.load_state_dict(
torch.load(self.file_type_embed, weights_only=True)
)

## to save model parameters
# torch.save(des.state_dict(), 'model_weights.pth')
Expand Down Expand Up @@ -299,8 +301,8 @@ def test_descriptor(self):
**dparams,
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
type_embd_dict = torch.load(self.file_type_embed)
source_dict = torch.load(self.file_model_param, weights_only=True)
type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1(
target_dict,
source_dict,
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def test_descriptor(self):
**dparams,
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
source_dict = torch.load(self.file_model_param, weights_only=True)
# type_embd of repformer is removed
source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias")
type_embd_dict = torch.load(self.file_type_embed)
type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
target_dict = translate_type_embd_dicts_to_dpa2(
target_dict,
source_dict,
Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_saveload_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
optimizer.zero_grad()
if read:
wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
wrapper.load_state_dict(
torch.load(model_file, map_location=env.DEVICE, weights_only=True)
)
os.remove(model_file)
else:
torch.save(wrapper.state_dict(), model_file)
Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_saveload_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
optimizer.zero_grad()
if read:
wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
wrapper.load_state_dict(
torch.load(model_file, map_location=env.DEVICE, weights_only=True)
)
os.remove(model_file)
else:
torch.save(wrapper.state_dict(), model_file)
Expand Down

0 comments on commit 911f41b

Please sign in to comment.