diff --git a/example/README.md b/example/README.md
index 37b6890..dbe7c2f 100644
--- a/example/README.md
+++ b/example/README.md
@@ -17,18 +17,6 @@ $ python run_zinc250k.py --datadir={YOUR_ZINC250K_DATASET_FOLDER} --train_mode={
 You can switch to the SELFIES version by using flag `--version=selfies`, but the package `selfies` is required.
 
-* To finetune a model, you need to prepare the dataset in the same format as below:
-```csv
-smiles,value
-CC(=O)C3(C)CCC4C2C=C(C)C1=CC(=O)CCC1(C)C2CCC34C,-5.27
-CC(=O)OC3(CCC4C2C=C(C)C1=CC(=O)CCC1(C)C2CCC34C)C(C)=O,-5.35
-CN(C(=O)COc1nc2ccccc2s1)c3ccccc3,-4.873
-O=C(Nc1ccccc1)Nc2ccccc2,-3.15
-Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1,-5.915
-CC2(C)C1CCC(C)(C1)C2=O,-1.85
-CC1(C)C2CCC1(C)C(=O)C2,-1.96
-
-```
 
 ## JIT version?
 
@@ -40,7 +28,7 @@ from bayesianflow_for_chem.data import smiles2vec
 from bayesianflow_for_chem.tool import sample, inpaint
 
 model = ChemBFN.from_checkpoint("YOUR_MODEL.pt").eval().to("cuda")
-model = torch.jit.freeze(torch.jit.script(model), ["sample", "inpaint"])
+model = torch.jit.freeze(torch.jit.script(model), ["sample", "inpaint", "ode_sample", "ode_inpaint"])  # or model.compile()
 
 # ------- generate molecules -------
 smiles = sample(model, 1, 60, 100)
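The JIT hunk above widens the list of methods preserved through `torch.jit.freeze`; freezing strips any method not named there, so `ode_sample` and `ode_inpaint` would otherwise be unavailable on the frozen module. A minimal sketch of the updated usage, with `YOUR_MODEL.pt` as a placeholder checkpoint path:

```python
import torch
from bayesianflow_for_chem import ChemBFN
from bayesianflow_for_chem.tool import sample

# Script, then freeze. The second argument of torch.jit.freeze lists the
# methods that must stay callable on the frozen module.
model = ChemBFN.from_checkpoint("YOUR_MODEL.pt").eval().to("cuda")
model = torch.jit.freeze(
    torch.jit.script(model), ["sample", "inpaint", "ode_sample", "ode_inpaint"]
)

# As in the README snippet: 1 molecule, up to 60 tokens, 100 sampling steps.
smiles = sample(model, 1, 60, 100)
```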
+ return {"token": smiles2token(smiles), "value": torch.tensor(value)} + + checkpoint_callback = ModelCheckpoint(dirpath=workdir, monitor="val_loss") logger = loggers.TensorBoardLogger(logdir, args.name) trainer = L.Trainer( @@ -66,22 +74,24 @@ enable_progress_bar=False, ) -traindata = DataLoader( - CSVData(datadir / f"{args.name}_train.csv"), 32, True, collate_fn=collate -) -valdata = DataLoader(CSVData(datadir / f"{args.name}_val.csv"), 32, collate_fn=collate) -testdata = DataLoader( - CSVData(datadir / f"{args.name}_test.csv"), 32, collate_fn=collate -) +train_dataset = CSVData(datadir / f"{args.name}_train.csv") +train_dataset.map(encode) +train_dataloader = DataLoader(train_dataset, 32, True, collate_fn=collate) +val_dataset = CSVData(datadir / f"{args.name}_val.csv") +val_dataset.map(encode) +val_dataloader = DataLoader(val_dataset, 32, collate_fn=collate) +test_dataset = CSVData(datadir / f"{args.name}_test.csv") +test_dataset.map(encode) +test_dataloader = DataLoader(test_dataset, 32, collate_fn=collate) if __name__ == "__main__": os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" - trainer.fit(regressor, traindata, valdata) + trainer.fit(regressor, train_dataloader, val_dataloader) regressor.export_model(workdir) - result = test(model, regressor.mlp, testdata, l_hparam["mode"]) + result = test(model, regressor.mlp, test_dataloader, l_hparam["mode"]) print("last:", result) regressor = Regressor.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, model=model, mlp=mlp ) - result = test(regressor.model, regressor.mlp, testdata, l_hparam["mode"]) + result = test(regressor.model, regressor.mlp, test_dataloader, l_hparam["mode"]) print("best:", result) diff --git a/example/pretrain.py b/example/pretrain.py index 8d0a883..37861c3 100644 --- a/example/pretrain.py +++ b/example/pretrain.py @@ -4,7 +4,7 @@ pretraining. 
diff --git a/example/pretrain.py b/example/pretrain.py
index 8d0a883..37861c3 100644
--- a/example/pretrain.py
+++ b/example/pretrain.py
@@ -4,7 +4,7 @@
 pretraining.
 e.g.,
-$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="none"
+$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv"
 """
 import os
 import argparse
@@ -13,9 +13,9 @@
 from torch.utils.data import DataLoader
 from lightning.pytorch import loggers
 from lightning.pytorch.callbacks import ModelCheckpoint
-from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem import ChemBFN
 from bayesianflow_for_chem.train import Model
-from bayesianflow_for_chem.data import collate, CSVData, VOCAB_COUNT
+from bayesianflow_for_chem.data import smiles2token, collate, CSVData, VOCAB_COUNT
 
 cwd = Path(__file__).parent
 
@@ -23,27 +23,13 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--datafile", default="./train.csv", type=str, help="dataset file")
 parser.add_argument("--nepoch", default=15, type=int, help="number of epochs")
-parser.add_argument(
-    "--label_mode",
-    default="none",
-    type=str,
-    help="'none', 'class:x', or 'value:x' where x is the size of your guidance label",
-)
 args = parser.parse_args()
 
 workdir = cwd / "pretrain"
 logdir = cwd / "log"
 
-if args.label_mode.lower() == "none":
-    mlp = None
-elif "class" in args.label_mode.lower():
-    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512], True)
-elif "value" in args.label_mode.lower():
-    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512])
-else:
-    raise NotImplementedError
-model = Model(ChemBFN(VOCAB_COUNT), mlp)
+model = Model(ChemBFN(VOCAB_COUNT))
 checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000)
 logger = loggers.TensorBoardLogger(logdir, "pretrain")
 trainer = L.Trainer(
@@ -58,6 +44,8 @@
 if __name__ == "__main__":
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
-    data = DataLoader(CSVData(args.datafile), 512, True, collate_fn=collate)
+    dataset = CSVData(args.datafile)
+    dataset.map(lambda x: {"token": smiles2token(".".join(x["smiles"]))})
+    data = DataLoader(dataset, 512, True, collate_fn=collate)
     trainer.fit(model, data)
     model.export_model(workdir)
diff --git a/example/run_moses.py b/example/run_moses.py
index 091e841..33f9397 100644
--- a/example/run_moses.py
+++ b/example/run_moses.py
@@ -24,9 +24,9 @@
     VOCAB_COUNT,
     collate,
     load_vocab,
+    smiles2token,
     split_selfies,
     CSVData,
-    BaseCSVDataClass,
 )
 
@@ -49,6 +49,7 @@
     vocab_keys = VOCAB_KEYS
     dataset_file = args.datadir + "/train.csv"
     train_data = CSVData(dataset_file)
+    train_data.map(lambda x: {"token": smiles2token(".".join(x["smiles"]))})
 else:
     import selfies
 
@@ -78,16 +79,8 @@ def selfies2token(s):
             [1] + [vocab_dict[i] for i in split_selfies(s)] + [2], dtype=torch.long
         )
 
-    class SELData(BaseCSVDataClass):
-        def __getitem__(self, idx) -> None:
-            if torch.is_tensor(idx):
-                idx = idx.tolist()
-            d = self.data[idx + 1].replace("\n", "").split(",")
-            s = ".".join([d[i] for i in self.selfies_idx if d[i] != ""])
-            token = selfies2token(s)
-            return {"token": token}
-
-    train_data = SELData(dataset_file)
+    train_data = CSVData(dataset_file)
+    train_data.map(lambda x: {"token": selfies2token(".".join(x["selfies"]))})
 
 model = Model(ChemBFN(num_vocab))
 checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000)
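Both pretraining scripts above tokenise with `".".join(x["smiles"])` (or the SELFIES analogue). The join collapses a row with several molecule columns into one dot-separated string, the SMILES notation for a multi-component record, so a single token sequence covers the whole row. A small illustration with a made-up row:

```python
from bayesianflow_for_chem.data import smiles2token

# Made-up row shape: two SMILES fields in one CSV record.
row = {"smiles": ["CCO", "O=C=O"]}
token = smiles2token(".".join(row["smiles"]))  # tokenises "CCO.O=C=O"
```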
diff --git a/example/run_zinc250k.py b/example/run_zinc250k.py
index d9b9b61..4f69509 100644
--- a/example/run_zinc250k.py
+++ b/example/run_zinc250k.py
@@ -7,7 +7,6 @@
 $ python run_zinc250k.py --version=smiles --train_mode=sar --target=fa7 --samplestep=1000 --datadir="./dataset/zinc250k"
 """
 import os
-import csv
 import json
 import argparse
 from pathlib import Path
@@ -23,10 +22,10 @@
     VOCAB_COUNT,
     VOCAB_KEYS,
     CSVData,
-    BaseCSVDataClass,
     collate,
-    split_selfies,
     load_vocab,
+    smiles2token,
+    split_selfies,
 )
 
 parser = argparse.ArgumentParser()
@@ -39,27 +38,25 @@
 cwd = Path(__file__).parent
 targets = "parp1,fa7,5ht1b,braf,jak2".split(",")
+assert args.target in targets
 dataset_file = f"{args.datadir}/zinc250k.csv"
 workdir = cwd / f"zinc250k_{args.train_mode}/{args.target}_{args.version}"
 logdir = cwd / "log"
 max_epochs = 100
 l_hparam = {"lr": 5e-5, "lr_warmup_step": 1000, "uncond_prob": 0.2}
 
-with open(dataset_file, "r") as f:
-    _data = list(csv.reader(f))
-if "value" not in _data[0]:
-    # format dataset
-    dataset_file = dataset_file.replace(".csv", "_formatted.csv")
-    _data[0] = ["smiles"] + ["value"] * 7
-    with open(dataset_file, "w", newline="") as f:
-        writer = csv.writer(f)
-        writer.writerows(_data)
-
 if args.version.lower() == "smiles":
+
+    def encode(x):
+        smiles = x["smiles"][0]
+        value = [x["qed"], x["sa"], x[args.target]]
+        return {"token": smiles2token(smiles), "value": torch.tensor(value)}
+
     pad_len = 111
     num_vocab = VOCAB_COUNT
     vocab_keys = VOCAB_KEYS
-    train_data = CSVData(dataset_file, label_idx=[0, 1, targets.index(args.target) + 2])
+    train_data = CSVData(dataset_file)
+    train_data.map(encode)
 else:
     import selfies
 
@@ -97,27 +94,18 @@ def selfies2token(s):
         return torch.tensor(
             [1] + [vocab_dict[i] for i in split_selfies(s)] + [2], dtype=torch.long
         )
+
+    def encode(x):
+        s = x["selfies"][0]
+        value = [x["qed"], x["sa"], x[args.target]]
+        return {"token": selfies2token(s), "value": torch.tensor(value)}
 
-    class SELData(BaseCSVDataClass):
-        def __getitem__(self, idx) -> None:
-            if torch.is_tensor(idx):
-                idx = idx.tolist()
-            d = self.data[idx + 1].replace("\n", "").split(",")
-            values = [
-                float(d[i]) if d[i].strip() != "" else torch.inf for i in self.value_idx
-            ]
-            s = ".".join([d[i] for i in self.selfies_idx if d[i] != ""])
-            token = selfies2token(s)
-            out_dict = {"token": token}
-            if len(values) != 0:
-                out_dict["value"] = torch.tensor(values, dtype=torch.float32)
-            return out_dict
-
-    train_data = SELData(dataset_file, label_idx=[0, 1, targets.index(args.target) + 2])
+    train_data = CSVData(dataset_file)
+    train_data.map(encode)
 
 bfn = ChemBFN(num_vocab)
 mlp = MLP([3, 256, 512])
-model = Model(bfn, mlp, l_hparam)
+model = Model(bfn, mlp, hparam=l_hparam)
 if args.train_mode == "normal":
     model.model.semi_autoregressive = False
 elif args.train_mode == "sar":
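One constraint in the zinc250k change is easy to miss: `encode` emits a three-component conditioning vector (QED, SA score, and the docking score of the chosen target), and the guidance head is `MLP([3, 256, 512])`, so the MLP's input width must match the length of that vector. A sketch of the pairing, reusing the diff's hyperparameters and assuming `uncond_prob` keeps its usual classifier-free-guidance meaning (the probability of training without the condition):

```python
from bayesianflow_for_chem import ChemBFN, MLP
from bayesianflow_for_chem.data import VOCAB_COUNT
from bayesianflow_for_chem.train import Model

# Three conditioning values per molecule -> input width 3 on the guidance MLP.
bfn = ChemBFN(VOCAB_COUNT)
mlp = MLP([3, 256, 512])
model = Model(
    bfn, mlp, hparam={"lr": 5e-5, "lr_warmup_step": 1000, "uncond_prob": 0.2}
)
```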