diff --git a/README.md b/README.md new file mode 100644 index 0000000..8021e90 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# ChemBFN: Bayesian Flow Network for Chemistry diff --git a/bayesianflow_for_chem/model.py b/bayesianflow_for_chem/model.py index 0a9a4db..e968a1e 100644 --- a/bayesianflow_for_chem/model.py +++ b/bayesianflow_for_chem/model.py @@ -356,13 +356,13 @@ def reconstruction_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tens x, logits = torch.broadcast_tensors(x[..., None], logits) return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean() - @torch.inference_mode() + @torch.jit.export def sample( self, batch_size: int, sequence_size: int, y: Optional[Tensor], - sample_step: int = 1000, + sample_step: int = 100, guidance_strength: float = 4.0, ) -> Tensor: """ @@ -375,11 +375,14 @@ def sample( :param guidance_strength: strength of conditional generation. It is not used if y is null. :return: probability distribution; shape: (n_b, n_t, n_vocab) """ - self.eval() theta = ( torch.ones((batch_size, sequence_size, self.K), device=self.beta.device) / self.K ) + if y is not None: + assert y.dim() == 3 + if y.shape[0] == 1: + y = y.repeat(batch_size, 1, 1) for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device): t = (i - 1).view(1, 1).repeat(batch_size, 1) / sample_step p = self.discrete_output_distribution(theta, t, y, guidance_strength) diff --git a/bayesianflow_for_chem/tool.py b/bayesianflow_for_chem/tool.py index 1ff4e53..93eeb13 100644 --- a/bayesianflow_for_chem/tool.py +++ b/bayesianflow_for_chem/tool.py @@ -141,11 +141,12 @@ def split_dataset( writer.writerows([header] + val_set) +@torch.no_grad() def sample( model: ChemBFN, batch_size: int, sequence_size: int, - sample_step: int = 1000, + sample_step: int = 100, y: Optional[Tensor] = None, guidance_strength: float = 4.0, device: Union[str, torch.device, None] = None, diff --git a/example/README.md b/example/README.md index 895edd0..ac45f74 100644 --- 
a/example/README.md +++ b/example/README.md @@ -9,4 +9,4 @@ $ python run_moses.py --datadir={YOUR_MOSES_DATASET_FOLDER} --samplestep=100 $ python run_guacamol.py --datadir={YOUR_GUACAMOL_DATASET_FOLDER} --samplestep=100 ``` -You can switch to the SELFIES version by using flag `--version=selfies`, but the package `selfies` is required before hand. \ No newline at end of file +You can switch to the SELFIES version by using flag `--version=selfies`, but the package `selfies` is required. \ No newline at end of file diff --git a/example/finetune.py b/example/finetune.py index b3972c8..2707262 100644 --- a/example/finetune.py +++ b/example/finetune.py @@ -4,7 +4,7 @@ Fine-tuning. e.g., -$ python fintune.py --name=esol --nepoch=300 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="regression" +$ python finetune.py --name=esol --nepoch=100 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="regression" """ import os import argparse diff --git a/example/pretrain.py b/example/pretrain.py index e69de29..8d0a883 100644 --- a/example/pretrain.py +++ b/example/pretrain.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# author: Nianze A. TAO (SUENO Omozawa) +""" +pretraining. 
+ +e.g., +$ python pretrain.py --nepoch=15 --datafile="./dataset/train.csv" --label_mode="none" +""" +import os +import argparse +from pathlib import Path +import lightning as L +from torch.utils.data import DataLoader +from lightning.pytorch import loggers +from lightning.pytorch.callbacks import ModelCheckpoint +from bayesianflow_for_chem import ChemBFN, MLP +from bayesianflow_for_chem.train import Model +from bayesianflow_for_chem.data import collate, CSVData, VOCAB_COUNT + + +cwd = Path(__file__).parent + +parser = argparse.ArgumentParser() +parser.add_argument("--datafile", default="./train.csv", type=str, help="dataset file") +parser.add_argument("--nepoch", default=15, type=int, help="number of epochs") +parser.add_argument( + "--label_mode", + default="none", + type=str, + help="'none', 'class:x', or 'value:x' where x is the size of your guidance label", +) +args = parser.parse_args() + +workdir = cwd / "pretrain" +logdir = cwd / "log" + +if args.label_mode.lower() == "none": + mlp = None +elif "class" in args.label_mode.lower(): + mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512], True) +elif "value" in args.label_mode.lower(): + mlp = MLP([int(args.label_mode.split(":")[-1]), 256, 512]) +else: + raise NotImplementedError + +model = Model(ChemBFN(VOCAB_COUNT), mlp) +checkpoint_callback = ModelCheckpoint(dirpath=workdir, every_n_train_steps=1000) +logger = loggers.TensorBoardLogger(logdir, "pretrain") +trainer = L.Trainer( + max_epochs=args.nepoch, + log_every_n_steps=500, + logger=logger, + accelerator="gpu", + callbacks=[checkpoint_callback], + enable_progress_bar=False, +) + + +if __name__ == "__main__": + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" + data = DataLoader(CSVData(args.datafile), 512, True, collate_fn=collate) + trainer.fit(model, data) + model.export_model(workdir) diff --git a/other-requirements.txt b/other-requirements.txt new file mode 100644 index 0000000..344027c --- /dev/null +++ b/other-requirements.txt @@ 
-0,0 +1,4 @@ +moses>=1.0 +selfies>=2.1.1 +guacamol>=0.5.5 +tensorboard>=2.16.0 \ No newline at end of file