
Commit e25c06e

Update example scripts.

Augus1999 authored Jan 20, 2025
1 parent 6048478 commit e25c06e
Showing 5 changed files with 52 additions and 85 deletions.
14 changes: 1 addition & 13 deletions example/README.md
@@ -17,18 +17,6 @@ $ python run_zinc250k.py --datadir={YOUR_ZINC250K_DATASET_FOLDER} --train_mode={

You can switch to the SELFIES version by using the flag `--version=selfies`, but the `selfies` package is required.
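For instance, a sketch reusing the example command from `run_zinc250k.py` with only the version flag changed:

```
$ python run_zinc250k.py --version=selfies --train_mode=sar --target=fa7 --samplestep=1000 --datadir="./dataset/zinc250k"
```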

-* To finetune a model, you need to prepare the dataset in the same format as below:
-```csv
-smiles,value
-CC(=O)C3(C)CCC4C2C=C(C)C1=CC(=O)CCC1(C)C2CCC34C,-5.27
-CC(=O)OC3(CCC4C2C=C(C)C1=CC(=O)CCC1(C)C2CCC34C)C(C)=O,-5.35
-CN(C(=O)COc1nc2ccccc2s1)c3ccccc3,-4.873
-O=C(Nc1ccccc1)Nc2ccccc2,-3.15
-Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1,-5.915
-CC2(C)C1CCC(C)(C1)C2=O,-1.85
-CC1(C)C2CCC1(C)C(=O)C2,-1.96
-```

## JIT version?

@@ -40,7 +28,7 @@ from bayesianflow_for_chem.data import smiles2vec
from bayesianflow_for_chem.tool import sample, inpaint

model = ChemBFN.from_checkpoint("YOUR_MODEL.pt").eval().to("cuda")
-model = torch.jit.freeze(torch.jit.script(model), ["sample", "inpaint"])
+model = torch.jit.freeze(torch.jit.script(model), ["sample", "inpaint", "ode_sample", "ode_inpaint"])
# or model.compile()
# ------- generate molecules -------
smiles = sample(model, 1, 60, 100)
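# NOTE: the positional arguments of sample() are assumed here to be
# (model, batch_size, sequence_length, sample_step); the last value mirrors
# the --samplestep flag used by run_zinc250k.py. Check the docstring of
# bayesianflow_for_chem.tool.sample before relying on this reading.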
32 changes: 21 additions & 11 deletions example/finetune.py
@@ -9,14 +9,15 @@
import os
import argparse
from pathlib import Path
+import torch
import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch import loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from bayesianflow_for_chem import ChemBFN, MLP
from bayesianflow_for_chem.tool import test
from bayesianflow_for_chem.train import Regressor
-from bayesianflow_for_chem.data import collate, CSVData
+from bayesianflow_for_chem.data import smiles2token, collate, CSVData


cwd = Path(__file__).parent
@@ -55,6 +56,13 @@
mlp = MLP([512, 256, args.ntask], dropout=args.dropout)
regressor = Regressor(model, mlp, l_hparam)


+def encode(x):
+    smiles = x["smiles"][0]
+    value = x["value"]  # set your own value tag!
+    return {"token": smiles2token(smiles), "value": torch.tensor(value)}


checkpoint_callback = ModelCheckpoint(dirpath=workdir, monitor="val_loss")
logger = loggers.TensorBoardLogger(logdir, args.name)
trainer = L.Trainer(
@@ -66,22 +74,24 @@
    enable_progress_bar=False,
)

-traindata = DataLoader(
-    CSVData(datadir / f"{args.name}_train.csv"), 32, True, collate_fn=collate
-)
-valdata = DataLoader(CSVData(datadir / f"{args.name}_val.csv"), 32, collate_fn=collate)
-testdata = DataLoader(
-    CSVData(datadir / f"{args.name}_test.csv"), 32, collate_fn=collate
-)
+train_dataset = CSVData(datadir / f"{args.name}_train.csv")
+train_dataset.map(encode)
+train_dataloader = DataLoader(train_dataset, 32, True, collate_fn=collate)
+val_dataset = CSVData(datadir / f"{args.name}_val.csv")
+val_dataset.map(encode)
+val_dataloader = DataLoader(val_dataset, 32, collate_fn=collate)
+test_dataset = CSVData(datadir / f"{args.name}_test.csv")
+test_dataset.map(encode)
+test_dataloader = DataLoader(test_dataset, 32, collate_fn=collate)

if __name__ == "__main__":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
-    trainer.fit(regressor, traindata, valdata)
+    trainer.fit(regressor, train_dataloader, val_dataloader)
    regressor.export_model(workdir)
-    result = test(model, regressor.mlp, testdata, l_hparam["mode"])
+    result = test(model, regressor.mlp, test_dataloader, l_hparam["mode"])
    print("last:", result)
    regressor = Regressor.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path, model=model, mlp=mlp
    )
-    result = test(regressor.model, regressor.mlp, testdata, l_hparam["mode"])
+    result = test(regressor.model, regressor.mlp, test_dataloader, l_hparam["mode"])
    print("best:", result)
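For reference, a minimal sketch of the CSV layout that `encode` above assumes: a `smiles` column with one molecule per row and a `value` column with the regression target (rows taken from the finetuning example formerly in `example/README.md`; substitute your own value tag):

```csv
smiles,value
O=C(Nc1ccccc1)Nc2ccccc2,-3.15
CC2(C)C1CCC(C)(C1)C2=O,-1.85
```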
26 changes: 7 additions & 19 deletions example/pretrain.py
@@ -4,7 +4,7 @@
pretraining.
e.g.,
-$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="none"
+$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv"
"""
import os
import argparse
@@ -13,37 +13,23 @@
from torch.utils.data import DataLoader
from lightning.pytorch import loggers
from lightning.pytorch.callbacks import ModelCheckpoint
-from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem import ChemBFN
from bayesianflow_for_chem.train import Model
-from bayesianflow_for_chem.data import collate, CSVData, VOCAB_COUNT
+from bayesianflow_for_chem.data import smiles2token, collate, CSVData, VOCAB_COUNT


cwd = Path(__file__).parent

parser = argparse.ArgumentParser()
parser.add_argument("--datafile", default="./train.csv", type=str, help="dataset file")
parser.add_argument("--nepoch", default=15, type=int, help="number of epochs")
-parser.add_argument(
-    "--label_mode",
-    default="none",
-    type=str,
-    help="'none', 'class:x', or 'value:x' where x is the size of your guidance label",
-)
args = parser.parse_args()

workdir = cwd / "pretrain"
logdir = cwd / "log"

-if args.label_mode.lower() == "none":
-    mlp = None
-elif "class" in args.label_mode.lower():
-    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512], True)
-elif "value" in args.label_mode.lower():
-    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512])
-else:
-    raise NotImplementedError

-model = Model(ChemBFN(VOCAB_COUNT), mlp)
+model = Model(ChemBFN(VOCAB_COUNT))
checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000)
logger = loggers.TensorBoardLogger(logdir, "pretrain")
trainer = L.Trainer(
@@ -58,6 +44,8 @@

if __name__ == "__main__":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
-    data = DataLoader(CSVData(args.datafile), 512, True, collate_fn=collate)
+    dataset = CSVData(args.datafile)
+    dataset.map(lambda x: {"token": smiles2token(".".join(x["smiles"]))})
+    data = DataLoader(dataset, 512, True, collate_fn=collate)
    trainer.fit(model, data)
    model.export_model(workdir)
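A minimal sketch of what the mapping above does to a row, under the assumption (inferred from the `".".join(x["smiles"])` call) that `CSVData` yields each column as a list of strings:

```python
from bayesianflow_for_chem.data import smiles2token

row = {"smiles": ["CCO", "O=C=O"]}  # hypothetical two-column row
# Components are joined into one dot-separated SMILES before tokenisation.
token = smiles2token(".".join(row["smiles"]))  # tokenises "CCO.O=C=O"
```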
15 changes: 4 additions & 11 deletions example/run_moses.py
@@ -24,9 +24,9 @@
    VOCAB_COUNT,
    collate,
    load_vocab,
+    smiles2token,
    split_selfies,
    CSVData,
-    BaseCSVDataClass,
)


@@ -49,6 +49,7 @@
    vocab_keys = VOCAB_KEYS
    dataset_file = args.datadir + "/train.csv"
    train_data = CSVData(dataset_file)
+    train_data.map(lambda x: {"token": smiles2token(".".join(x["smiles"]))})
else:
    import selfies

@@ -78,16 +79,8 @@ def selfies2token(s):
            [1] + [vocab_dict[i] for i in split_selfies(s)] + [2], dtype=torch.long
        )

-    class SELData(BaseCSVDataClass):
-        def __getitem__(self, idx) -> None:
-            if torch.is_tensor(idx):
-                idx = idx.tolist()
-            d = self.data[idx + 1].replace("\n", "").split(",")
-            s = ".".join([d[i] for i in self.selfies_idx if d[i] != ""])
-            token = selfies2token(s)
-            return {"token": token}

-    train_data = SELData(dataset_file)
+    train_data = CSVData(dataset_file)
+    train_data.map(lambda x: {"token": selfies2token(".".join(x["selfies"]))})

model = Model(ChemBFN(num_vocab))
checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000)
50 changes: 19 additions & 31 deletions example/run_zinc250k.py
@@ -7,7 +7,6 @@
$ python run_zinc250k.py --version=smiles --train_mode=sar --target=fa7 --samplestep=1000 --datadir="./dataset/zinc250k"
"""
import os
-import csv
import json
import argparse
from pathlib import Path
@@ -23,10 +22,10 @@
    VOCAB_COUNT,
    VOCAB_KEYS,
    CSVData,
-    BaseCSVDataClass,
    collate,
-    split_selfies,
    load_vocab,
+    smiles2token,
+    split_selfies,
)

parser = argparse.ArgumentParser()
@@ -39,27 +38,25 @@

cwd = Path(__file__).parent
targets = "parp1,fa7,5ht1b,braf,jak2".split(",")
assert args.target in targets
dataset_file = f"{args.datadir}/zinc250k.csv"
workdir = cwd / f"zinc250k_{args.train_mode}/{args.target}_{args.version}"
logdir = cwd / "log"
max_epochs = 100
l_hparam = {"lr": 5e-5, "lr_warmup_step": 1000, "uncond_prob": 0.2}

-with open(dataset_file, "r") as f:
-    _data = list(csv.reader(f))
-if "value" not in _data[0]:
-    # format dataset
-    dataset_file = dataset_file.replace(".csv", "_formatted.csv")
-    _data[0] = ["smiles"] + ["value"] * 7
-    with open(dataset_file, "w", newline="") as f:
-        writer = csv.writer(f)
-        writer.writerows(_data)

if args.version.lower() == "smiles":

+    def encode(x):
+        smiles = x["smiles"][0]
+        value = [x["qed"], x["sa"], x[args.target]]
+        return {"token": smiles2token(smiles), "value": torch.tensor(value)}

    pad_len = 111
    num_vocab = VOCAB_COUNT
    vocab_keys = VOCAB_KEYS
-    train_data = CSVData(dataset_file, label_idx=[0, 1, targets.index(args.target) + 2])
+    train_data = CSVData(dataset_file)
+    train_data.map(encode)
else:
    import selfies

@@ -97,27 +94,18 @@ def selfies2token(s):
        return torch.tensor(
            [1] + [vocab_dict[i] for i in split_selfies(s)] + [2], dtype=torch.long
        )

+    def encode(x):
+        s = x["selfies"][0]
+        value = [x["qed"], x["sa"], x[args.target]]
+        return {"token": selfies2token(s), "value": torch.tensor(value)}

-    class SELData(BaseCSVDataClass):
-        def __getitem__(self, idx) -> None:
-            if torch.is_tensor(idx):
-                idx = idx.tolist()
-            d = self.data[idx + 1].replace("\n", "").split(",")
-            values = [
-                float(d[i]) if d[i].strip() != "" else torch.inf for i in self.value_idx
-            ]
-            s = ".".join([d[i] for i in self.selfies_idx if d[i] != ""])
-            token = selfies2token(s)
-            out_dict = {"token": token}
-            if len(values) != 0:
-                out_dict["value"] = torch.tensor(values, dtype=torch.float32)
-            return out_dict

-    train_data = SELData(dataset_file, label_idx=[0, 1, targets.index(args.target) + 2])
+    train_data = CSVData(dataset_file)
+    train_data.map(encode)

bfn = ChemBFN(num_vocab)
mlp = MLP([3, 256, 512])
-model = Model(bfn, mlp, l_hparam)
+model = Model(bfn, mlp, hparam=l_hparam)
if args.train_mode == "normal":
    model.model.semi_autoregressive = False
elif args.train_mode == "sar":
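For reference, the `encode` functions above look rows up by column name, so the ZINC250k CSV is assumed to carry named headers in place of the positional `["smiles"] + ["value"] * 7` header written by the deleted formatting block: a `smiles` (or `selfies`) column plus `qed`, `sa`, and one column per docking target. A hypothetical header row under that assumption:

```csv
smiles,qed,sa,parp1,fa7,5ht1b,braf,jak2
```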
