Skip to content

Commit

Permalink
Update training.py
Browse files Browse the repository at this point in the history
Signed-off-by: Anchor Yu <[email protected]>
  • Loading branch information
1azyking authored Oct 23, 2024
1 parent fefe066 commit f456699
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def get_dataloader_and_buffer(_data, _params):
if dist.is_available()
else 0, # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
collate_fn=lambda batch: batch, # prevent extra conversion
pin_memory=True,
)
with torch.device("cpu"):
Expand Down Expand Up @@ -284,10 +285,6 @@ def get_lr(lr_params):
}

# Model
dp_random.seed(training_params["seed"])
if training_params["seed"] is not None:
torch.manual_seed(training_params["seed"])

self.model = get_model_for_wrapper(model_params, loss_param_tmp)

# Loss
Expand All @@ -312,7 +309,6 @@ def get_lr(lr_params):
)

# Data
dp_random.seed(training_params["seed"])
if not self.multi_task:
self.get_sample_func = single_model_stat(
self.model,
Expand Down Expand Up @@ -1116,7 +1112,7 @@ def get_data(self, is_train=True, task_key="Default"):
batch_data = next(iter(self.validation_data[task_key]))

for key in batch_data.keys():
if key == "sid" or key == "fid" or key == "box":
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
continue
elif not isinstance(batch_data[key], list):
if batch_data[key] is not None:
Expand Down

0 comments on commit f456699

Please sign in to comment.