# -*- coding: utf-8 -*-
# author: Nianze A. TAO (SUENO Omozawa)
"""
Fine-tuning.
e.g.,
$ python finetune.py --name=esol --nepoch=100 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="regression" --dropout=0.0
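or, for a classification task (the dataset name below is only illustrative; as noted at the
--ntask argument, --ntask=2 is the usual choice for classification),
$ python finetune.py --name=bace --nepoch=100 --datadir="./dataset/moleculenet" --ckpt="./ckpt/zinc15_40m.pt" --mode="classification" --ntask=2 --dropout=0.5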
"""
import os
import argparse
from pathlib import Path
import torch
import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch import loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from bayesianflow_for_chem import ChemBFN, MLP
from bayesianflow_for_chem.tool import test
from bayesianflow_for_chem.train import Regressor
from bayesianflow_for_chem.data import smiles2token, collate, CSVData
cwd = Path(__file__).parent
parser = argparse.ArgumentParser()
parser.add_argument(
    "--datadir", default="./moleculenet", type=str, help="dataset folder"
)
parser.add_argument(
    "--ckpt", default="./ckpt/zinc15_40m.pt", type=str, help="ckpt file"
)
parser.add_argument("--name", default="esol", type=str, help="dataset name")
parser.add_argument("--nepoch", default=100, type=int, help="number of epochs")
# in most cases, --ntask=2 when --mode=classification and --ntask=1 when --mode=regression
parser.add_argument("--ntask", default=1, type=int, help="number of tasks")
parser.add_argument(
    "--mode", default="regression", type=str, help="regression or classification"
)
parser.add_argument("--dropout", default=0.5, type=float, help="dropout rate")
args = parser.parse_args()
workdir = cwd / args.name
logdir = cwd / "log"
datadir = Path(args.datadir)
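
# Fine-tuning hyper-parameters forwarded to the Regressor wrapper below. Judging by
# the key names (an assumption, not verified against the library internals), they set
# a reduce-on-plateau learning-rate schedule with warm-up, the peak learning rate,
# and whether the ChemBFN backbone is frozen during fine-tuning.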
l_hparam = {
    "mode": args.mode,
    "lr_scheduler_factor": 0.8,
    "lr_scheduler_patience": 20,
    "lr_warmup_step": 1000 if args.mode == "regression" else 100,
    "max_lr": 1e-4,
    "freeze": False,
}
model = ChemBFN.from_checkpoint(args.ckpt)
mlp = MLP([512, 256, args.ntask], dropout=args.dropout)
regressor = Regressor(model, mlp, l_hparam)


def encode(x):
    """Map one raw CSV record to token and target tensors."""
    smiles = x["smiles"][0]
    value = x["value"]  # set your own value tag!
    value = [float(i) if i != "" else torch.inf for i in value]
    return {"token": smiles2token(smiles), "value": torch.tensor(value)}
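
# encode() maps one CSV record to the model inputs: a "smiles" column holding the
# SMILES string and a "value" column holding the target(s); empty entries are
# replaced by torch.inf (presumably treated as missing values downstream).
# An assumed minimal CSV layout (illustrative values) would look like:
#   smiles,value
#   CCO,-0.77
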
checkpoint_callback = ModelCheckpoint(dirpath=workdir, monitor="val_loss")
logger = loggers.TensorBoardLogger(logdir, args.name)
trainer = L.Trainer(
    max_epochs=args.nepoch,
    log_every_n_steps=5,
    logger=logger,
    accelerator="gpu",
    callbacks=[checkpoint_callback],
    enable_progress_bar=False,
)
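# The pipeline below expects three splits named {name}_train.csv, {name}_val.csv and
# {name}_test.csv inside --datadir; each is mapped through encode() and batched with
# the package's collate function (batch size 32; only the training split is shuffled).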
train_dataset = CSVData(datadir / f"{args.name}_train.csv")
train_dataset.map(encode)
train_dataloader = DataLoader(train_dataset, 32, True, collate_fn=collate)
val_dataset = CSVData(datadir / f"{args.name}_val.csv")
val_dataset.map(encode)
val_dataloader = DataLoader(val_dataset, 32, collate_fn=collate)
test_dataset = CSVData(datadir / f"{args.name}_test.csv")
test_dataset.map(encode)
test_dataloader = DataLoader(test_dataset, 32, collate_fn=collate)


if __name__ == "__main__":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
    trainer.fit(regressor, train_dataloader, val_dataloader)
    regressor.export_model(workdir)
    # evaluate the model from the final epoch
    result = test(model, regressor.mlp, test_dataloader, l_hparam["mode"])
    print("last:", result)
    # reload and evaluate the best checkpoint (lowest val_loss)
    regressor = Regressor.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path, model=model, mlp=mlp
    )
    result = test(regressor.model, regressor.mlp, test_dataloader, l_hparam["mode"])
    print("best:", result)