Skip to content

Commit

Permalink
11/06/2024 Update
Browse files: browse the repository at this point in the history
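* Updated dataset class: when `label_idx` is set, only the values at those indices are kept.
* Updated example scripts: `finetune.py` gains an `--ntask` option that sets the MLP output size, and `run_guacamol.py` writes its metrics to a file name that includes the version and sample step.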
  • Loading branch information
Augus1999 authored Jun 11, 2024
1 parent 79068bc commit a9819b5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 2 additions & 0 deletions bayesianflow_for_chem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def __getitem__(self, idx: Union[int, Tensor]) -> Dict[str, Dict[str, Tensor]]:
values = [
float(d[i]) if d[i].strip() != "" else torch.inf for i in self.value_idx
]
if self.label_idx:
values = [values[i] for i in self.label_idx]
token = smiles2token(smiles)
if len(values) != 0:
value = torch.tensor(values, dtype=torch.float32)
Expand Down
1 change: 1 addition & 0 deletions bayesianflow_for_chem/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
self.mlp = mlp
self.model.requires_grad_(not hparam["freeze"])
self.save_hyperparameters(hparam, ignore=["model", "mlp"])
assert hparam["mode"] in ("regression", "classification")

@staticmethod
def _mask_label(label: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down
8 changes: 3 additions & 5 deletions example/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
parser.add_argument("--name", default="esol", type=str, help="dataset name")
parser.add_argument("--nepoch", default=100, type=int, help="number of epochs")
parser.add_argument("--ntask", default=1, type=int, help="number of tasks")
parser.add_argument(
"--mode", default="regression", type=str, help="regression or classification"
)
Expand All @@ -49,7 +50,7 @@
}

model = ChemBFN.from_checkpoint(args.ckpt)
mlp = MLP([512, 256, 1])
mlp = MLP([512, 256, args.ntask])
regressor = Regressor(model, mlp, l_hparam)

checkpoint_callback = ModelCheckpoint(dirpath=workdir, monitor="val_loss")
Expand All @@ -74,9 +75,6 @@
if __name__ == "__main__":
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
trainer.fit(regressor, traindata, valdata)
regressor = Regressor.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path, model=model, mlp=mlp
)
regressor.export_model(workdir)
result = test(regressor.model, regressor.mlp, testdata, l_hparam["mode"])
result = test(model, regressor.mlp, testdata, l_hparam["mode"])
print(result)
3 changes: 2 additions & 1 deletion example/run_guacamol.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def generate(self, number_samples: int):
assess_distribution_learning(
generator,
chembl_training_file=args.datadir + "/guacamol_v1_train.smiles",
json_output_file=cwd / f"guacamol_sample_{i}_metrics.json",
json_output_file=cwd
/ f"guacamol_{args.version}_sample_{i}_metrics_samplestep_{args.samplestep}.json",
benchmark_version="v2",
)

0 comments on commit a9819b5

Please sign in to comment.