Skip to content

Commit

Permalink
fix issue with different sequences and order
Browse files Browse the repository at this point in the history
  • Loading branch information
gcorso committed Jul 8, 2024
1 parent a487ef8 commit 6d065cd
Showing 1 changed file with 22 additions and 31 deletions.
53 changes: 22 additions & 31 deletions utils/inference_utils.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,39 @@
import copy
import os
import pickle

import torch
from Bio.PDB import PDBParser
from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs, MolFromSmiles
from torch_geometric.data import Dataset, HeteroData
import numpy as np
import torch
import prody as pr
import esm

from datasets.constants import three_to_one
from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence


def get_sequences_from_pdbfile(file_path):
biopython_parser = PDBParser()
structure = biopython_parser.get_structure('random_id', file_path)
structure = structure[0]
sequence = None
for i, chain in enumerate(structure):
seq = ''
for res_idx, residue in enumerate(chain):
if residue.get_resname() == 'HOH':
continue
residue_coords = []
c_alpha, n, c = None, None, None
for atom in residue:
if atom.name == 'CA':
c_alpha = list(atom.get_vector())
if atom.name == 'N':
n = list(atom.get_vector())
if atom.name == 'C':
c = list(atom.get_vector())
if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid
try:
seq += three_to_one[residue.get_resname()]
except Exception as e:
seq += '-'
print("encountered unknown AA: ", residue.get_resname(), ' in the complex. Replacing it with a dash - .')

pdb = pr.parsePDB(file_path)
seq = pdb.ca.getSequence()
one_hot = get_onehot_sequence(seq)

chain_ids = np.zeros(len(one_hot))
res_chain_ids = pdb.ca.getChids()
res_seg_ids = pdb.ca.getSegnames()
res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
ids = np.unique(res_chain_ids)

for i, id in enumerate(ids):
chain_ids[res_chain_ids == id] = i

s_temp = np.argmax(one_hot[res_chain_ids == id], axis=1)
s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp])

if sequence is None:
sequence = seq
sequence = s
else:
sequence += (":" + seq)
sequence += (":" + s)

return sequence

Expand Down

0 comments on commit 6d065cd

Please sign in to comment.