20/06/2024 Update
* Updated example scripts.
* Added other-requirements.txt
Augus1999 authored Jun 19, 2024
1 parent a9819b5 commit 3b25da8
Showing 7 changed files with 78 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
# ChemBFN: Bayesian Flow Network for Chemistry
9 changes: 6 additions & 3 deletions bayesianflow_for_chem/model.py
@@ -356,13 +356,13 @@ def reconstruction_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tensor:
        x, logits = torch.broadcast_tensors(x[..., None], logits)
        return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()

-    @torch.inference_mode()
+    @torch.jit.export
    def sample(
        self,
        batch_size: int,
        sequence_size: int,
        y: Optional[Tensor],
-        sample_step: int = 1000,
+        sample_step: int = 100,
        guidance_strength: float = 4.0,
    ) -> Tensor:
        """
@@ -375,11 +375,14 @@
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :return: probability distribution; shape: (n_b, n_t, n_vocab)
        """
-        self.eval()
        theta = (
            torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
            / self.K
        )
+        if y is not None:
+            assert y.dim() == 3
+            if y.shape[0] == 1:
+                y = y.repeat(batch_size, 1, 1)
        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
            t = (i - 1).view(1, 1).repeat(batch_size, 1) / sample_step
            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
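The new guidance block above requires a 3-D label tensor and broadcasts a singleton batch dimension to `batch_size`. A minimal sketch of calling the updated method, assuming the pretrain script's `ChemBFN(VOCAB_COUNT)` constructor; the label width of 512 and the untrained model are illustrative, not fixed by this commit:

```python
import torch
from bayesianflow_for_chem import ChemBFN
from bayesianflow_for_chem.data import VOCAB_COUNT

model = ChemBFN(VOCAB_COUNT)   # in practice, load trained weights first
y = torch.randn(1, 1, 512)     # 3-D guidance label; batch dim of 1 is repeated
p = model.sample(
    batch_size=16,
    sequence_size=64,
    y=y,                       # pass None for unconditional sampling
    sample_step=100,           # new default introduced by this commit
    guidance_strength=4.0,     # unused when y is None
)
# p: probability distribution of shape (16, 64, n_vocab) per the docstring
```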
3 changes: 2 additions & 1 deletion bayesianflow_for_chem/tool.py
@@ -141,11 +141,12 @@ def split_dataset(
        writer.writerows([header] + val_set)


+@torch.no_grad()
def sample(
    model: ChemBFN,
    batch_size: int,
    sequence_size: int,
-    sample_step: int = 1000,
+    sample_step: int = 100,
    y: Optional[Tensor] = None,
    guidance_strength: float = 4.0,
    device: Union[str, torch.device, None] = None,
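With the wrapper now under `@torch.no_grad()`, a typical call only changes in its default step count. A hedged sketch against the signature shown above; weight loading is omitted and the device choice is an assumption:

```python
import torch
from bayesianflow_for_chem import ChemBFN
from bayesianflow_for_chem.data import VOCAB_COUNT
from bayesianflow_for_chem.tool import sample

model = ChemBFN(VOCAB_COUNT)   # restore pretrained weights here in real use
out = sample(
    model,
    batch_size=8,
    sequence_size=64,
    sample_step=100,           # previously 1000
    y=None,                    # optional 3-D guidance label
    guidance_strength=4.0,
    device="cuda" if torch.cuda.is_available() else None,
)
```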
2 changes: 1 addition & 1 deletion example/README.md
@@ -9,4 +9,4 @@ $ python run_moses.py --datadir={YOUR_MOSES_DATASET_FOLDER} --samplestep=100
$ python run_guacamol.py --datadir={YOUR_GUACAMOL_DATASET_FOLDER} --samplestep=100
```

-You can switch to the SELFIES version by using flag `--version=selfies`, but the package `selfies` is required before hand.
+You can switch to the SELFIES version by using flag `--version=selfies`, but the package `selfies` is required.
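For instance, a SELFIES run would presumably look like this, installing the dependency first:

```
$ pip install selfies
$ python run_moses.py --datadir={YOUR_MOSES_DATASET_FOLDER} --samplestep=100 --version=selfies
```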
2 changes: 1 addition & 1 deletion example/finetune.py
@@ -4,7 +4,7 @@
Fine-tuning.
e.g.,
-$ python fintune.py --name=esol --nepoch=300 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="regression"
+$ python fintune.py --name=esol --nepoch=100 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="regression"
"""
import os
import argparse
63 changes: 63 additions & 0 deletions example/pretrain.py
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
# author: Nianze A. TAO (SUENO Omozawa)
"""
pretraining.
e.g.,
$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="none"
"""
import os
import argparse
from pathlib import Path
import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch import loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from bayesianflow_for_chem import ChemBFN, MLP
from bayesianflow_for_chem.train import Model
from bayesianflow_for_chem.data import collate, CSVData, VOCAB_COUNT


cwd = Path(__file__).parent

parser = argparse.ArgumentParser()
parser.add_argument("--datafile", default="./train.csv", type=str, help="dataset file")
parser.add_argument("--nepoch", default=15, type=int, help="number of epochs")
parser.add_argument(
    "--label_mode",
    default="none",
    type=str,
    help="'none', 'class:x', or 'value:x' where x is the size of your guidance label",
)
args = parser.parse_args()

workdir = cwd / "pretrain"
logdir = cwd / "log"

if args.label_mode.lower() == "none":
    mlp = None
elif "class" in args.label_mode.lower():
    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512], True)
elif "value" in args.label_mode.lower():
    mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512])
else:
    raise NotImplementedError

model = Model(ChemBFN(VOCAB_COUNT), mlp)
checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000)
logger = loggers.TensorBoardLogger(logdir, "pretrain")
trainer = L.Trainer(
    max_epochs=args.nepoch,
    log_every_n_steps=500,
    logger=logger,
    accelerator="gpu",
    callbacks=[checkpoint_callback],
    enable_progress_bar=False,
)


if __name__ == "__main__":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
    data = DataLoader(CSVData(args.datafile), 512, True, collate_fn=collate)
    trainer.fit(model, data)
    model.export_model(workdir)
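The `--label_mode` flag decides whether a guidance MLP is built: `class:x` constructs `MLP([x, 256, 512], True)` (the `True` apparently selecting class-label mode), while `value:x` builds the same head without it for continuous targets. Illustrative invocations; the label sizes 10 and 3 are placeholders, not values from this commit:

```
$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="none"
$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="class:10"
$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="value:3"
```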
4 changes: 4 additions & 0 deletions other-requirements.txt
@@ -0,0 +1,4 @@
moses>=1.0
selfies>=2.1.1
guacamol>=0.5.5
tensorboard>=2.16.0
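These extras appear to cover only the example scripts (MOSES and GuacaMol benchmarks, SELFIES encoding, TensorBoard logging) rather than the core package; presumably they are installed with:

```
$ pip install -r other-requirements.txt
```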
