-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathShow_Epoch.py
75 lines (49 loc) · 2.12 KB
/
Show_Epoch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys
sys.path.insert(0, './Modules/')
from build_encoding import read_decodings, decode
from rewards import evaluate_chem_mol
from rdkit.Chem import Draw
import rdkit.Chem as Chem
import matplotlib.pyplot as plt
import numpy as np
import argparse
from rdkit import rdBase
# Turn off RDKit's 'rdApp.error' log channel for the whole process.
# NOTE(review): presumably done so that molecules which fail to decode or
# kekulize (see safe_decode) do not flood the console with RDKit error
# messages — confirm intent.
rdBase.DisableLog('rdApp.error')
def safe_decode(x, decodings):
    """Decode one encoded molecule, returning None if it cannot be used.

    Args:
        x: A single encoded molecule in the format accepted by
            ``build_encoding.decode`` (e.g. one row of the arrays that
            ``main`` loads from ``History/``).
        decodings: The decoding table returned by ``read_decodings()``.

    Returns:
        The decoded RDKit molecule (kekulized in place), or ``None`` if
        either decoding or kekulization raised an exception.
    """
    try:
        mol = decode(x, decodings)
        # Kekulize doubles as a validity check: RDKit raises for molecules
        # it cannot kekulize, and also if decode handed back None.
        Chem.Kekulize(mol)
    except Exception:
        # Deliberately not a bare ``except:`` so that KeyboardInterrupt and
        # SystemExit still propagate while decoding a large batch.
        return None
    return mol
def _decode_epoch(epoch, decodings):
    """Load and decode the saved input/output molecule arrays for ``epoch``.

    Reads ``History/in-{epoch}.npy`` and ``History/out-{epoch}.npy``.
    Input molecules are decoded with ``decode`` (errors propagate); output
    molecules go through ``safe_decode`` so invalid ones become ``None``.

    Returns:
        ``(in_mols, out_mols)`` as two lists.
    """
    in_encoded = np.load("History/in-{}.npy".format(epoch))
    out_encoded = np.load("History/out-{}.npy".format(epoch))
    in_mols = [decode(m, decodings) for m in in_encoded]
    out_mols = [safe_decode(m, decodings) for m in out_encoded]
    return in_mols, out_mols


def _changed_pairs_by_score(in_mols, out_mols):
    """Select valid, actually-modified ``(input, output)`` pairs, best first.

    A pair is kept when the output decoded successfully (is not None) and
    its SMILES from ``Chem.MolToSmiles`` differs from the input's. Kept
    pairs are sorted by ``np.sum(evaluate_chem_mol(output))`` in descending
    order; the sort is stable, so ties keep their original relative order.
    """
    # NOTE(review): assumes the in/out arrays are aligned row-for-row.
    pairs = [
        (m_in, m_out)
        for m_in, m_out in zip(in_mols, out_mols)
        if m_out is not None
        and Chem.MolToSmiles(m_out) != Chem.MolToSmiles(m_in)
    ]
    return sorted(
        pairs,
        key=lambda pair: np.sum(evaluate_chem_mol(pair[1])),
        reverse=True,
    )


def _write_smiles(savefile, pairs):
    """Write a header and then one ``input ; output`` SMILES line per pair."""
    with open(savefile, "w") as f:
        f.write("Initial molecule ; Modified molecule\n")
        for m_in, m_out in pairs:
            f.write(f'{Chem.MolToSmiles(m_in)} ; {Chem.MolToSmiles(m_out)}\n')


def main(epoch, savefile=None, imagefile=None, max_image_pairs=25):
    """Display the molecules the model modified at a given epoch.

    Loads the epoch's input/output molecules from ``History/`` and keeps the
    pairs whose output is valid and differs from the input. Those pairs are
    ranked by the summed ``evaluate_chem_mol`` score of the output, highest
    first.

    Args:
        epoch: Epoch number used to build the ``History/`` file names.
        savefile: If given, every kept pair is written to this file as
            ``input SMILES ; output SMILES`` lines.
        imagefile: If given, the grid image is also saved to this path.
        max_image_pairs: Maximum number of pairs drawn in the grid image
            (default 25, i.e. 50 molecules as before). The SMILES file
            always contains all kept pairs.
    """
    decodings = read_decodings()
    in_mols, out_mols = _decode_epoch(epoch, decodings)
    pairs = _changed_pairs_by_score(in_mols, out_mols)

    if pairs:
        # Flatten to [in0, out0, in1, out1, ...] so that with two molecules
        # per row each grid row shows one input next to its modification.
        flat = [mol for pair in pairs for mol in pair]
        plot = Draw.MolsToGridImage(flat[:2 * max_image_pairs], molsPerRow=2)
        if imagefile is not None:
            plot.save(imagefile)
        plot.show()
    else:
        # Nothing to draw: avoid building/showing an empty grid image.
        print("No valid modified molecules found for epoch {}".format(epoch))

    if savefile is not None:
        _write_smiles(savefile, pairs)
# Command-line interface. ``-epoch`` is converted by argparse itself
# (type=int), so a non-numeric value yields a usage error rather than a
# ValueError traceback.
parser = argparse.ArgumentParser()
parser.add_argument("-SMILES", dest="SMILEFile", help="Save SMILE strings to file", default=None)
parser.add_argument("-epoch", dest="epoch", type=int, help="Epoch to display", required=True)
parser.add_argument("-image", dest="image", help="File to save image in", default=None)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args.epoch, args.SMILEFile, args.image)