diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py
index 1a091e074e..5447c8735b 100644
--- a/deepmd/pt/loss/loss.py
+++ b/deepmd/pt/loss/loss.py
@@ -9,9 +9,12 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.plugin import (
+    make_plugin_registry,
+)
 
 
-class TaskLoss(torch.nn.Module, ABC):
+class TaskLoss(torch.nn.Module, ABC, make_plugin_registry("loss")):
     def __init__(self, **kwargs):
         """Construct loss."""
         super().__init__()
@@ -38,3 +41,23 @@ def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
             whether the property is found
         """
         return loss if bool(find_property) else torch.nan
+
+    @classmethod
+    def get_loss(cls, loss_params: dict) -> "TaskLoss":
+        """Get the loss module by the parameters.
+
+        By default, all the parameters are directly passed to the constructor.
+        If not, override this method.
+
+        Parameters
+        ----------
+        loss_params : dict
+            The loss parameters
+
+        Returns
+        -------
+        TaskLoss
+            The loss module
+        """
+        loss = cls(**loss_params)
+        return loss
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 211e1ba564..466080d34c 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -28,6 +28,7 @@
     EnergySpinLoss,
     EnergyStdLoss,
     PropertyLoss,
+    TaskLoss,
     TensorLoss,
 )
 from deepmd.pt.model.model import (
@@ -1260,7 +1261,8 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
         loss_params["task_dim"] = task_dim
         return PropertyLoss(**loss_params)
     else:
-        raise NotImplementedError
+        loss_params["starter_learning_rate"] = start_lr
+        return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params)
 
 
 def get_single_model(