Merge pull request #408 from rostro36/main
Add use_custom_templates option
jnwei authored Jul 18, 2024
2 parents f37d0d9 + 7d22739 commit c48f850
Showing 4 changed files with 114 additions and 60 deletions.
13 changes: 7 additions & 6 deletions docs/source/Inference.md
@@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view
 - `--data_random_seed`: Specifies a random seed to use.
 - `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads.
 - `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{"globals.relax.max_iterations": 10}` specifies 10 as the maximum number of relaxation iterations. See [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) for the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will fall back to the defaults specified by your `config_preset`.
+- `--use_custom_template`: Uses all .cif files in `template_mmcif_dir` as template input. Make sure the chains of interest have the identifier _A_ and have the same length as the input sequence. The same templates will be read for all sequences that are passed for inference.


### Advanced Options for Increasing Efficiency
@@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
 #### Long sequence inference
 To minimize memory usage during inference on long sequences, consider the following changes:

 - As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either.
 - Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Power users can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint.
 - Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time.
 - As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model.
+- Disable FlashAttention, which seems unstable on long sequences.

 Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficient AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option.
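
As an illustration, here is a minimal Python sketch of enabling this preset programmatically. It assumes the `long_sequence_inference` keyword accepted by `openfold.config.model_config` and the `globals.use_lma`, `globals.use_flash`, and `globals.offload_inference` config fields; check `openfold/config.py` in your checkout for the exact names.

```python
# Sketch only: build a model config with the long-sequence preset enabled.
from openfold.config import model_config

config = model_config("model_1_ptm", long_sequence_inference=True)

# Inspect which of the interventions described above were switched on.
print(config.globals.use_lma)            # low-memory attention
print(config.globals.use_flash)          # FlashAttention (expected off)
print(config.globals.offload_inference)  # CPU offloading
```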

 Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
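
For instance, a two-record FASTA like the following (hypothetical sequences) is treated as a single two-chain complex:

```
>chainA
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
>chainB
MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQ
```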
19 changes: 16 additions & 3 deletions openfold/data/data_pipeline.py
@@ -23,8 +23,19 @@
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
 import numpy as np
 import torch
-from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
-from openfold.data.templates import get_custom_template_features, empty_template_feats
+from openfold.data import (
+    templates,
+    parsers,
+    mmcif_parsing,
+    msa_identifiers,
+    msa_pairing,
+    feature_processing_multimer,
+)
+from openfold.data.templates import (
+    get_custom_template_features,
+    empty_template_feats,
+    CustomHitFeaturizer,
+)
 from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
 from openfold.np import residue_constants, protein

@@ -38,7 +49,9 @@ def make_template_features(
     template_featurizer: Any,
 ) -> FeatureDict:
     hits_cat = sum(hits.values(), [])
-    if(len(hits_cat) == 0 or template_featurizer is None):
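+    # Use empty template features unless a CustomHitFeaturizer is in use,
+    # which reads its templates from disk and needs no search hits.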
+    if template_featurizer is None or (
+        len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
+    ):
         template_features = empty_template_feats(len(input_sequence))
     else:
         templates_result = template_featurizer.get_templates(
122 changes: 78 additions & 44 deletions openfold/data/templates.py
@@ -22,6 +22,7 @@
 import json
 import logging
 import os
+from pathlib import Path
 import re
 from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

@@ -947,55 +948,71 @@ def _process_single_hit(


 def get_custom_template_features(
-    mmcif_path: str,
-    query_sequence: str,
-    pdb_id: str,
-    chain_id: str,
-    kalign_binary_path: str):
-
-    with open(mmcif_path, "r") as mmcif_path:
-        cif_string = mmcif_path.read()
-
-    mmcif_parse_result = mmcif_parsing.parse(
-        file_id=pdb_id, mmcif_string=cif_string
-    )
-    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
-
-
-    mapping = {x:x for x, _ in enumerate(query_sequence)}
-
-
-    features, warnings = _extract_template_features(
-        mmcif_object=mmcif_parse_result.mmcif_object,
-        pdb_id=pdb_id,
-        mapping=mapping,
-        template_sequence=template_sequence,
-        query_sequence=query_sequence,
-        template_chain_id=chain_id,
-        kalign_binary_path=kalign_binary_path,
-        _zero_center_positions=True
-    )
-    features["template_sum_probs"] = [1.0]
-
-    # TODO: clean up this logic
-    template_features = {}
-    for template_feature_name in TEMPLATE_FEATURES:
-        template_features[template_feature_name] = []
-
-    for k in template_features:
-        template_features[k].append(features[k])
-
-    for name in template_features:
-        template_features[name] = np.stack(
-            template_features[name], axis=0
-        ).astype(TEMPLATE_FEATURES[name])
+    mmcif_path: str,
+    query_sequence: str,
+    pdb_id: str,
+    chain_id: Optional[str] = "A",
+    kalign_binary_path: Optional[str] = None,
+):
+    if os.path.isfile(mmcif_path):
+        template_paths = [Path(mmcif_path)]
+
+    elif os.path.isdir(mmcif_path):
+        template_paths = list(Path(mmcif_path).glob("*.cif"))
+    else:
+        logging.error("Custom template path %s does not exist", mmcif_path)
+        raise ValueError(f"Custom template path {mmcif_path} does not exist")
+
+    warnings = []
+    template_features = dict()
+    for template_path in template_paths:
+        logging.info("Featurizing template: %s", template_path)
+        # pdb_id is used only for error reporting; take the file name
+        pdb_id = Path(template_path).stem
+        with open(template_path, "r") as mmcif_path:
+            cif_string = mmcif_path.read()
+        mmcif_parse_result = mmcif_parsing.parse(
+            file_id=pdb_id, mmcif_string=cif_string
+        )
+        # identity mapping over the query, skipping gap characters ("-")
+        mapping = {
+            x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum()
+        }
+        realigned_sequence, realigned_mapping = _realign_pdb_template_to_query(
+            old_template_sequence=query_sequence,
+            template_chain_id=chain_id,
+            mmcif_object=mmcif_parse_result.mmcif_object,
+            old_mapping=mapping,
+            kalign_binary_path=kalign_binary_path,
+        )
+        curr_features, curr_warnings = _extract_template_features(
+            mmcif_object=mmcif_parse_result.mmcif_object,
+            pdb_id=pdb_id,
+            mapping=realigned_mapping,
+            template_sequence=realigned_sequence,
+            query_sequence=query_sequence,
+            template_chain_id=chain_id,
+            kalign_binary_path=kalign_binary_path,
+            _zero_center_positions=True,
+        )
+        curr_features["template_sum_probs"] = [
+            1.0
+        ]  # template given by user, 100% confident
+        template_features = {
+            curr_name: template_features.get(curr_name, []) + [curr_item]
+            for curr_name, curr_item in curr_features.items()
+        }
+        warnings.append(curr_warnings)
+    template_features = {
+        template_feature_name: np.stack(
+            template_features[template_feature_name], axis=0
+        ).astype(template_feature_type)
+        for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
+    }
+    return TemplateSearchResult(
+        features=template_features, errors=None, warnings=warnings
+    )



 @dataclasses.dataclass(frozen=True)
 class TemplateSearchResult:
     features: Mapping[str, Any]
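
A minimal usage sketch of the new function (the directory, query sequence, and kalign path are placeholders):

```python
# Hypothetical call: every *.cif under custom_templates/ becomes one template row.
result = get_custom_template_features(
    mmcif_path="custom_templates/",     # a single .cif file also works
    query_sequence="MKTAYIAKQRQISFVK",  # placeholder query
    pdb_id="unused",                    # overwritten per file with the file stem
    chain_id="A",                       # the chain of interest must be chain A
    kalign_binary_path="/usr/bin/kalign",
)
print(result.features["template_aatype"].shape[0])  # number of templates featurized
```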
@@ -1188,6 +1205,23 @@ def get_templates(
         )


+class CustomHitFeaturizer(TemplateHitFeaturizer):
+    """Featurizer for templates given in a folder.
+    The chain of interest has to be chain A and of the same sequence length as the input sequence."""
+    def get_templates(
+        self,
+        query_sequence: str,
+        hits: Sequence[parsers.TemplateHit],
+    ) -> TemplateSearchResult:
+        """Computes the templates for a given query sequence (more details above)."""
+        logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
+        return get_custom_template_features(
+            self._mmcif_dir,
+            query_sequence=query_sequence,
+            pdb_id="test",
+            chain_id="A",
+            kalign_binary_path=self._kalign_binary_path,
+        )
+
+
 class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
     def get_templates(
         self,
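For reference, a sketch of wiring the new featurizer into the data pipeline by hand (mirroring the run_pretrained_openfold.py change below; paths are placeholders):

```python
from openfold.data import data_pipeline, templates

template_featurizer = templates.CustomHitFeaturizer(
    mmcif_dir="custom_templates/",   # every *.cif in here is used as a template
    max_template_date="9999-12-31",  # dummy value; ignored by this featurizer
    max_hits=-1,                     # dummy value; ignored by this featurizer
    kalign_binary_path="/usr/bin/kalign",
)

# CustomHitFeaturizer ignores search hits; templates come from mmcif_dir instead.
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer)
```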
20 changes: 13 additions & 7 deletions run_pretrained_openfold.py
@@ -202,8 +202,15 @@ def main(args):
     )
 
     is_multimer = "multimer" in args.config_preset
-
-    if is_multimer:
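+    # Choose the custom featurizer when the flag exists and is set; it will
+    # featurize every .cif in template_mmcif_dir instead of using search hits.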
+    is_custom_template = "use_custom_template" in args and args.use_custom_template
+    if is_custom_template:
+        template_featurizer = templates.CustomHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date="9999-12-31",  # just a dummy, not used
+            max_hits=-1,  # just a dummy, not used
+            kalign_binary_path=args.kalign_binary_path
+        )
+    elif is_multimer:
         template_featurizer = templates.HmmsearchHitFeaturizer(
             mmcif_dir=args.template_mmcif_dir,
             max_template_date=args.max_template_date,
@@ -221,11 +228,9 @@ def main(args):
             release_dates_path=args.release_dates_path,
             obsolete_pdbs_path=args.obsolete_pdbs_path
         )
-
     data_processor = data_pipeline.DataPipeline(
         template_featurizer=template_featurizer,
     )
-
     if is_multimer:
         data_processor = data_pipeline.DataPipelineMultimer(
             monomer_data_pipeline=data_processor,
@@ -238,7 +243,6 @@ def main(args):

     np.random.seed(random_seed)
     torch.manual_seed(random_seed + 1)
-
     feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
@@ -313,7 +317,6 @@ def main(args):
         )
 
         feature_dicts[tag] = feature_dict
-
         processed_feature_dict = feature_processor.process_features(
             feature_dict, mode='predict', is_multimer=is_multimer
         )
@@ -400,6 +403,10 @@ def main(args):
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser.add_argument(
"--use_custom_template", action="store_true", default=False,
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs."""
@@ -494,5 +501,4 @@
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)

main(args)
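
Putting it together, a hypothetical end-to-end invocation of the new option (directory paths are placeholders, and the positional arguments follow the script's usage as documented in Inference.md; double-check them against your checkout):

```bash
python3 run_pretrained_openfold.py fasta_dir/ custom_templates/ \
    --use_custom_template \
    --output_dir predictions/ \
    --model_device cuda:0
```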
