-
Notifications
You must be signed in to change notification settings - Fork 11
/
sample.py
149 lines (142 loc) · 5.3 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import argparse
from pathlib import Path
import numpy as np
from ampal.amino_acids import standard_amino_acids
from design_utils.sampling_utils import (
apply_temp_to_probs,
sample_with_multiprocessing,
save_as,
)
from design_utils.utils import (
extract_sequence_from_pred_matrix,
get_rotamer_codec,
load_datasetmap,
)
def main_sample(args):
    """Sample sequences from a residue/rotamer prediction matrix.

    Loads a CSV prediction matrix and its dataset map, optionally applies a
    softmax temperature, extracts per-PDB probability distributions, draws
    ``args.sample_n`` sequences per structure with multiprocessing, and saves
    the samples to fasta/json files.

    Args:
        args: Parsed CLI namespace. Reads ``path_to_pred_matrix``,
            ``path_to_datasetmap``, ``seed``, ``temperature``,
            ``predict_rotamers``, ``support_old_datasetmap``, ``sample_n``,
            ``workers`` and ``save_as``.

    Returns:
        Output paths returned by ``save_as`` for the saved sample files.

    Raises:
        FileNotFoundError: If the prediction matrix or dataset map is missing.
    """
    # Seed the legacy global NumPy RNG. The original code called
    # np.random.default_rng(seed=args.seed) and discarded the returned
    # Generator, which left the global random state untouched (a no-op).
    # NOTE(review): assumes sample_with_multiprocessing draws from the global
    # np.random state — confirm in design_utils.sampling_utils.
    np.random.seed(args.seed)
    # Sanitise paths. Use explicit raises rather than assert (asserts are
    # stripped under `python -O`, silently disabling the validation).
    args.path_to_pred_matrix = Path(args.path_to_pred_matrix)
    args.path_to_datasetmap = Path(args.path_to_datasetmap)
    if not args.path_to_pred_matrix.exists():
        raise FileNotFoundError(
            f"Prediction Matrix file {args.path_to_pred_matrix} does not exist"
        )
    if not args.path_to_datasetmap.exists():
        raise FileNotFoundError(
            f"Dataset Map file {args.path_to_datasetmap} does not exist"
        )
    # Load prediction matrix (rows: residue positions, columns: class probs):
    prediction_matrix = np.genfromtxt(
        args.path_to_pred_matrix, delimiter=",", dtype=np.float64
    )
    # Load dataset map (maps matrix rows back to PDB codes / positions):
    datasetmap = load_datasetmap(
        args.path_to_datasetmap, is_old=args.support_old_datasetmap
    )
    # Apply temperature factor to sharpen/flatten the distributions; t == 1
    # leaves the probabilities unchanged, so skip the work entirely.
    if args.temperature != 1:
        prediction_matrix = apply_temp_to_probs(prediction_matrix, t=args.temperature)
    if args.predict_rotamers:
        # Model output is over 338 rotamer classes; map each rotamer label
        # (e.g. "ALA_...") to the 1-letter code of its parent residue.
        _, flat_categories = get_rotamer_codec()
        # 3-letter -> 1-letter amino acid conversion table:
        res_to_r = dict(zip(standard_amino_acids.values(), standard_amino_acids.keys()))
        flat_categories = [res_to_r[res.split("_")[0]] for res in flat_categories]
    else:
        # 20-residue output: downstream helpers use their default categories.
        flat_categories = None
    # Extract per-PDB sequence and probability dictionaries:
    (
        pdb_to_sequence,
        pdb_to_probability,
        pdb_to_real_sequence,
        _,
        _,
    ) = extract_sequence_from_pred_matrix(
        datasetmap,
        prediction_matrix,
        rotamers_categories=flat_categories,
        old_datasetmap=args.support_old_datasetmap,
    )
    pdb_codes = list(pdb_to_probability.keys())
    print(
        f"Ready to sample {args.sample_n} for each of the {len(pdb_codes)} proteins from {args.path_to_pred_matrix.stem}."
    )
    pdb_to_sample = sample_with_multiprocessing(
        args.workers, pdb_codes, args.sample_n, pdb_to_probability, flat_categories
    )
    # Save sampled sequences to fasta/json files:
    output_paths = save_as(
        pdb_to_sample,
        filename=f"{args.path_to_pred_matrix.stem}_temp_{args.temperature}_n_{args.sample_n}_{pdb_codes[0]}",
        mode=args.save_as,
    )
    return output_paths
if __name__ == "__main__":
    # Build the CLI, parse arguments, and run the sampling pipeline.
    arg_parser = argparse.ArgumentParser(description="")
    arg_parser.add_argument(
        "--path_to_pred_matrix",
        type=str,
        help="Path to prediction matrix file ending with .csv",
    )
    arg_parser.add_argument(
        "--path_to_datasetmap",
        type=str,
        default="datasetmap.txt",
        help="Path to dataset map ending with .txt",
    )
    arg_parser.add_argument(
        "--predict_rotamers",
        action="store_true",
        default=False,
        help="Whether model outputs predictions for 338 rotamers (True) or 20 residues (False).",
    )
    arg_parser.add_argument(
        "--sample_n",
        type=int,
        default=100,
        help="Number of samples to be drawn from the distribution.",
    )
    # Optional value: bare `--save_as` falls back to const="all".
    arg_parser.add_argument(
        "--save_as",
        type=str,
        nargs="?",
        default="all",
        const="all",
        choices=["fasta", "json", "all"],
        help="Whether to save as fasta and json (default: all) or either of them.",
    )
    arg_parser.add_argument(
        "--workers", type=int, default=8, help="Number of workers to use (default: 8)"
    )
    arg_parser.add_argument(
        "--temperature",
        type=float,
        default=1,
        help="Temperature factor to apply to softmax prediction. (default: 1.0 - unchanged)",
    )
    arg_parser.add_argument(
        "--support_old_datasetmap",
        action="store_true",
        default=False,
        help="Whether model to import from the old datasetmap (default: False)",
    )
    arg_parser.add_argument(
        "--seed", type=int, default=42, help="random seed (default: 42)"
    )
    cli_args = arg_parser.parse_args()
    main_sample(cli_args)