diff --git a/drivers/py/pes/pet.py b/drivers/py/pes/pet.py index 4a5234782..44e45d472 100644 --- a/drivers/py/pes/pet.py +++ b/drivers/py/pes/pet.py @@ -26,11 +26,13 @@ class PET_driver(Dummy_driver): def __init__(self, args=None, verbose=False): self.error_msg = """ -The PET driver requires specification of a .json model file fitted with - the PET tools, and a template file that describes the chemical makeup of -the structure. +The PET driver requires (a) a path to the results/experiment_name folder emitted by pet_train + (b) a path to an ase.io.read-able file with a prototype structure -Example: python driver.py -m pet -u -o model.json,template.xyz +Other arguments to the pet.SingleStructCalculator class can be optionally +supplied in key=value form after the required arguments. + +Example: python driver.py -m pet -u -o "path/to/results/name,template.xyz,device=cuda" """ super().__init__(args, verbose) @@ -43,11 +45,17 @@ def check_arguments(self): This loads the potential and atoms template in PET """ - arglist = self.args + args = self.args + + if len(args) >= 2: + self.model_path = args[0] + self.template = args[1] + kwargs = {} + if len(args) > 2: + for arg in args[2:]: + key, value = arg.split("=") + kwargs[key] = value - if len(arglist) == 2: - self.model_path = arglist[0] - self.template = arglist[1] else: sys.exit(self.error_msg) @@ -55,7 +63,7 @@ def check_arguments(self): self.template_ase.arrays["forces"] = np.zeros_like(self.template_ase.positions) self.pet_calc = PETCalc( self.model_path, - default_hypers_path=self.model_path + "/default_hypers.yaml", + **kwargs, ) def __call__(self, cell, pos):