diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index b1e4907da9..481b612557 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -185,6 +185,7 @@ def get_dataloader_and_buffer(_data, _params):
                 if dist.is_available()
                 else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
                 drop_last=False,
+                collate_fn=lambda batch: batch,  # prevent extra conversion
                 pin_memory=True,
             )
             with torch.device("cpu"):
@@ -284,10 +285,6 @@ def get_lr(lr_params):
         }
 
         # Model
-        dp_random.seed(training_params["seed"])
-        if training_params["seed"] is not None:
-            torch.manual_seed(training_params["seed"])
-
         self.model = get_model_for_wrapper(model_params, loss_param_tmp)
 
         # Loss
@@ -312,7 +309,6 @@ def get_lr(lr_params):
         )
 
         # Data
-        dp_random.seed(training_params["seed"])
         if not self.multi_task:
             self.get_sample_func = single_model_stat(
                 self.model,
@@ -1116,7 +1112,7 @@ def get_data(self, is_train=True, task_key="Default"):
             batch_data = next(iter(self.validation_data[task_key]))
 
         for key in batch_data.keys():
-            if key == "sid" or key == "fid" or key == "box":
+            if key == "sid" or key == "fid" or key == "box" or "find_" in key:
                 continue
             elif not isinstance(batch_data[key], list):
                 if batch_data[key] is not None:
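
For context (not part of the patch): a minimal, runnable sketch of what the identity `collate_fn` in the first hunk changes. By default, PyTorch's `DataLoader` runs `default_collate`, which stacks samples into freshly allocated batched tensors; `lambda batch: batch` hands back the raw list of samples untouched, which is the "extra conversion" the inline comment refers to. `ToyDataset` below is a hypothetical stand-in, not deepmd's actual data pipeline.

```python
# Sketch using only stock PyTorch; ToyDataset is a hypothetical stand-in
# for deepmd's dataset, used only to show the collate_fn difference.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Each sample is already a dict of tensors/scalars.
        return {"coord": torch.zeros(3), "sid": idx}


data = ToyDataset()

# Default collate: samples are merged; tensors gain a new batch dimension.
default_batch = next(iter(DataLoader(data, batch_size=2)))
print(type(default_batch), default_batch["coord"].shape)  # dict, torch.Size([2, 3])

# Identity collate: the "batch" is just the untouched list of samples.
raw_batch = next(iter(DataLoader(data, batch_size=2, collate_fn=lambda b: b)))
print(type(raw_batch), len(raw_batch))  # list, 2
```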
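Similarly, a hypothetical standalone rendering of the device-transfer loop touched by the last hunk: `find_*` entries (deepmd's data dicts use keys such as `find_energy` as scalar presence flags for labels) are now skipped alongside `sid`/`fid`/`box` instead of being sent to the device. `DEVICE`, the batch contents, and the trailing `.to(...)` line are assumptions about the surrounding code, not quoted from the patch.

```python
# Hypothetical rendering of the skip logic; DEVICE and the batch contents
# are stand-ins for the trainer's actual state.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

batch_data = {
    "coord": torch.zeros(2, 3),
    "box": torch.eye(3),
    "find_energy": 1.0,  # scalar flag: has no .to(), so it must be skipped
}

for key in batch_data.keys():
    if key == "sid" or key == "fid" or key == "box" or "find_" in key:
        continue  # metadata and find_* flags stay as-is on the host
    elif not isinstance(batch_data[key], list):
        if batch_data[key] is not None:
            # assumed continuation: move the remaining tensors to the device
            batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True)
```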