diff --git a/doc/source/api/index_Align.rst b/doc/source/api/index_Align.rst index 4b252c315..06665bfb3 100644 --- a/doc/source/api/index_Align.rst +++ b/doc/source/api/index_Align.rst @@ -4,6 +4,12 @@ BioSimSpace.Align ================= The *Align* package provides functionality for aligning and merging molecules. + +.. automodule:: BioSimSpace.Align + +.. toctree:: + :maxdepth: 1 + Molecules are aligned using a Maximum Common Substructure (MCS) search, which is used to find mappings between atom indices in the two molecules. Functionality is provided for sorting the mappings according to a scoring @@ -34,8 +40,3 @@ Some examples: # The resulting "merged-molecule" can be used in free energy perturbation # simulations. merged = BSS.Align.merge(mol0, mol1, mappings) - -.. automodule:: BioSimSpace.Align - -.. toctree:: - :maxdepth: 1 diff --git a/doc/source/api/index_FreeEnergy.rst b/doc/source/api/index_FreeEnergy.rst index 0d931ecb0..ccf107b53 100644 --- a/doc/source/api/index_FreeEnergy.rst +++ b/doc/source/api/index_FreeEnergy.rst @@ -6,6 +6,13 @@ BioSimSpace.FreeEnergy The *FreeEnergy* package contains tools to configure, run, and analyse *relative* free energy simulations. +.. automodule:: BioSimSpace.FreeEnergy + +.. toctree:: + :maxdepth: 1 + +As well as the :class:`protocol ` used for production + Free-energy perturbation simulations require a :class:`System ` containing a *merged* molecule that can be *perturbed* between two molecular end states by use @@ -28,6 +35,9 @@ perturbable molecule created by merging two ligands, ``ligA`` and ``ligB``, perturbable molecule. We assume that each molecule/system has been appropriately minimised and equlibrated. +Relative binding free-energy (RBFE) +----------------------------------- + To setup, run, and analyse a binding free-energy calculation: .. 
code-block:: python @@ -130,12 +140,6 @@ the path to a working directory to :class:`FreeEnergy.Relative.analyse ` used for production simulations, it is also possible to use :class:`FreeEnergy.Relative ` to setup and run simulations for minimising or equilibrating structures for each lambda window. See the @@ -143,3 +147,73 @@ for minimising or equilibrating structures for each lambda window. See the :class:`FreeEnergyEquilibration ` protocols for details. At present, these protocols are only supported when not using :class:`SOMD ` as the simulation engine. + +Alchemical Transfer Method (ATM) +-------------------------------- + +This package contains tools to configure, run, and analyse *relative* free +energy simulations using the *alchemical transfer method* developed by the +`Gallicchio lab `. + +Only available in the *OpenMM* engine, the *alchemical transfer method* +replaces the conventional notion of perturbing between two end states with +a single system containing both the free and bound ligand. The relative free +energy of binding is then associated with the swapping of the bound and free +ligands. + +The *alchemical transfer method* has a few advantages over the conventional +approach, mainly arising from its relative simplicity and flexibility. The +method is particularly well-suited to the study of difficult ligand +transformations, such as scaffold-hopping and charge change perturbations. +The presence of both ligands in the same system also replaces the conventional +idea of _legs_, combining free, bound, forward and reverse legs into a +single simulation. + +In order to perform a relative free energy calculation using the +*alchemical transfer method*, the user requires a protein and two ligands, as +well as knowledge of any common core shared between the two ligands. +ATM-compatible systems can be created from these elements using the +:class:`FreeEnergy.ATM ` class. + +.. code-block:: python + + from BioSimSpace.FreeEnergy import ATMSetup + + ... 
+ + # Create an ATM setup object. 'protein', 'ligand1' and 'ligand2' must be + # BioSimSpace Molecule objects. + # 'ligand1' is bound in the lambda=0 state, 'ligand2' is bound in the lambda=1 state. + atm_setup = ATMSetup(receptor=protein, ligand_bound=ligand1, ligand_free=ligand2) + + # Now create the BioSimSpace system. Here is where knowledge of the common core is required. + # ligand_bound_rigid_core and ligand_free_rigid_core are lists of integers, each of length three, + # which define the indices of the common core atoms in the ligands. + # Displacement is the desired distance between the centres of mass of the two ligands. + system, data = atm_setup.prepare( + ligand_bound_rigid_core=[1, 2, 3], + ligand_free_rigid_core=[1, 2, 3], + displacement="22A" + ) + + # The prepare function returns two objects: a prepared BioSimSpace system that is ready + # for ATM simulation, and a data dictionary containing information relevant to ATM calculations. + # This dictionary does not need to be kept, as the information is also encoded in the system + # object, but it may be useful for debugging. + + +Preparing the system for production runs is slightly more complex than in +the conventional approach, as the system will need to be annealed to an +intermediate lambda value, and then equilibrated at that value. The +:ref:`protocol ` sub-module contains functionality for +equilibrating and annealing systems for ATM simulations. + +Once the production simulations have been completed, the user can analyse +the data using the :func:`analyse ` function. + +.. code-block:: python + + from BioSimSpace.FreeEnergy import ATM + + # Analyse the simulation data to get the free energy difference and associated error. + ddg, error = ATM.analyse("path/to/working/directory") diff --git a/doc/source/index.rst b/doc/source/index.rst index 3c20502b2..92ec687e1 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -63,7 +63,7 @@ Tutorials ========= .. 
toctree:: - :maxdepth: 2 + :maxdepth: 1 tutorials/index @@ -71,7 +71,7 @@ Detailed Guides =============== .. toctree:: - :maxdepth: 2 + :maxdepth: 1 guides/index @@ -93,7 +93,7 @@ Support Contributing ============ .. toctree:: - :maxdepth: 2 + :maxdepth: 1 contributing/index contributors diff --git a/doc/source/tutorials/alchemical_transfer.rst b/doc/source/tutorials/alchemical_transfer.rst new file mode 100644 index 000000000..0ce5afc69 --- /dev/null +++ b/doc/source/tutorials/alchemical_transfer.rst @@ -0,0 +1,307 @@ +========================== +Alchemical Transfer Method +========================== + +In this tutorial, you will use BioSimSpace to set up and run a Relative Binding +Free Energy (RBFE) calculation using the `alchemical transfer method +`__ (ATM) on a pair of ligands bound to +`Tyrosine kinase 2 `__ (TYK2). + +.. note :: + ATM calculations are currently only available in OpenMM. As such, an environment + containing OpenMM is required to run this tutorial. + +This tutorial assumes that you are familiar with the concepts of the Alchemical +Transfer Method. If you are not, please see the `Gallicchio lab website +`__ as well as the corresponding +publications for an in-depth explanation of the method, as well as the +intricacies involved in setting up an ATM calculation. + +------------ +System Setup +------------ + +Import :mod:`BioSimSpace` using: + +>>> import BioSimSpace as BSS + +Now load the set of example molecules from a URL, via +:func:`BioSimSpace.IO.readMolecules`: + +>>> url = BSS.tutorialUrl() +>>> protein = BSS.IO.readMolecules([f"{url}/tyk2.prm7", f"{url}/tyk2.rst7"])[0] +>>> lig1 = BSS.IO.readMolecules([f"{url}/ejm_31.prm7", f"{url}/ejm_31.rst7"])[0] +>>> lig2 = BSS.IO.readMolecules([f"{url}/ejm_43.prm7", f"{url}/ejm_43.rst7"])[0] + +In order to run an ATM calculation, a single system containing both ligands and +the protein in their correct positions is needed. 
This can be created using +functionality provided in :class:`BioSimSpace.FreeEnergy.ATMSetup`. + +ATM calculations require that both ligands be present in the +system simultaneously, with one ligand bound to the receptor and the other free +in the solvent. As such, the first decision to be made when setting up an ATM +calculation is which ligand will be bound and which will be free. + +It is important to note that, while one ligand is chosen to be bound, `both` +ligands will be bound to the receptor at some point during the calculation. The +choice made here simply defines the initial state of the system, and by +extension the `direction` of the calculation. + +The first step in creating an ATM-compatible system in BioSimSpace is to create +an :class:`ATMSetup` object, which will be used to prepare the system: + +>>> atm_setup = BSS.FreeEnergy.ATMSetup(receptor=protein, +... ligand_bound=lig1, +... ligand_free=lig2 +... ) + +Before an ATM-ready system can be prepared there are decisions to be made +regarding the system setup, namely which atoms will be used to define the rigid +cores of the ligands, as well as those that make up the centre of mass of each +molecule. + +The choice of rigid core atoms is vital to the success of an ATM RBFE +calculation, and as such BioSimSpace provides a helper function to visualise the +choice made by the user. + +>>> BSS.FreeEnergy.ATMSetup.viewRigidCores( +... ligand_bound=lig1, +... ligand_free=lig2, +... ligand_bound_rigid_core=[14, 11, 15], +... ligand_free_rigid_core=[14, 11, 15] +... ) + +.. image:: images/alignment_visualisation.png + :alt: Visualisation of the rigid cores of the ligands. + +.. note :: + + In this case the choice of rigid core atoms is the same for both ligands, + but this is not always the case. The choice of these atoms should be made + on a ligand to ligand basis. + + For help in choosing the correct atoms, see the `Gallicchio lab tutorial + `__. 
+ +Now that a sensible choice of rigid core atoms has been made, there are a few +more choices to be made before the system can be prepared. The most important of +these is the choice of displacement vector, which defines the direction and +distance at which the free ligand will be placed relative to the bound ligand. +It is generally recommended that this displacement be at least 3 layers of water +molecules (> 10 Å) thick. If no displacement is provided a default choice of +[20Å, 20Å, 20Å] will be used. + +This is also the point at which a custom set of atoms can be chosen to define the +centre of mass of both the ligands and the protein. In the majority of cases it +should not be necessary to change the default choice of atoms, but the option is +there if needed and can be set using the ``ligand_bound_com_atoms`` and +``ligand_free_com_atoms`` arguments. + +Now that all the choices have been made, the system can be prepared: + +>>> system, atm_data = atm_setup.prepare( +... ligand_bound_rigid_core=[14, 11, 15], +... ligand_free_rigid_core=[14, 11, 15] +... ) + +The ``prepare`` function returns a pair of objects, the first is the prepared +protein-ligand-ligand system, and the second is a dictionary containing the +choices made during the setup process. This ``atm_data`` object will be passed to +protocols for minimisation, equilibration and production in order to ensure that +options chosen during setup are properly carried through. + +The prepared system can be visualised using BioSimSpace's built in visualisation +functionality: + +>>> v = BSS.Notebook.View(system) +>>> v.system() + +.. image:: images/tyk2_prepared.png + :alt: Visualisation of the prepared system. + +Now all that remains is to solvate the system. 
+ +>>> solvated = BSS.Solvent.tip3p(molecule=system, box=3 * [7 * BSS.Units.Length.nanometer]) + +------------------------------ +Minimisation and Equilibration +------------------------------ + +Now that the system is fully prepared, the next step is to minimise and +equilibrate. The minimisation and equilibration of systems using alchemical +transfer is more complex than standard systems, and is a multi-stage process. + +First, if positional restraints are needed, which is generally recommended for +ATM calculations, the decision of which atoms to restrain must be made. A +good choice for these atoms are the alpha carbons of the protein. These can be +found using BioSimSpace search syntax: + +>>> ca = [atom.index() for atom in solvated.search("atomname CA")] + +The system can now be minimised. Unlike standard minimisation, the minimisation +of an ATM system requires that several restraints be applied from the start. +These restraints are: **core alignment**, applied to atoms determined earlier, which +can be turned on or off by passing the ``core_alignment`` argument; **positional +restraints** applied to the alpha carbons listed above, set using the +``restraint`` argument; and a **centre of mass distance restraint**, which maintains +the distance between the centre of masses of the ligands, as well as the +distance between the centre of mass of the protein and ligands, set using the +``com_distance_restraint`` argument. The strength of these restraints is automatically +set to a set of default values that are generally suitable for most systems, but +can also be set manually by passing the relevant arguments to +:data:`BioSimSpace.Protocol.ATMMinimisation`: + +>>> minimisation = BSS.Protocol.ATMMinimisation( +... data=atm_data, +... core_alignment=True, +... restraint=ca, +... com_distance_restraint=True +... 
) + +This minimisation protocol can now be run as a standard BioSimSpace OpenMM +process: + +>>> minimisation_process = BSS.Process.OpenMM(solvated, minimisation) +>>> minimisation_process.start() +>>> minimisation_process.wait() +>>> minimised = minimisation_process.getSystem(block=True) + +Now the first stage of equilibration can be run. Similar to the minimisation, +this protocol has several restraints that are applied from the start: + +>>> equilibration = BSS.Protocol.ATMEquilibration( +... data=atm_data, +... core_alignment=True, +... restraint=ca, +... com_distance_restraint=True, +... runtime="100ps" +... ) +>>> equilibrate_process = BSS.Process.OpenMM(minimised, equilibration, platform="CUDA") +>>> equilibrate_process.start() +>>> equilibrate_process.wait() +>>> equilibrated = equilibrate_process.getSystem(block=True) + +.. note :: + The equilibration protocol is set to run for 100ps. This is a relatively + short time, and should be increased for production runs. + + Here the "CUDA" platform is explicitly set. It is highly recommended to use + a GPU platform for equilibration and production runs, as the calculations are + computationally expensive. + +Now that the system has been minimised and equilibrated without the ATMForce +present, it needs to be added to the system. The first stage of this +introduction is annealing, which by default will gradually increase the value of +λ from 0 to 0.5 over a number of cycles: + +>>> annealing = BSS.Protocol.ATMAnnealing( +... data=atm_data, +... core_alignment=True, +... restraint=ca, +... com_distance_restraint=True, +... runtime="100ps", +... anneal_numcycles=10 +... ) +>>> annealing_process = BSS.Process.OpenMM(equilibrated, annealing, platform="CUDA") +>>> annealing_process.start() +>>> annealing_process.wait() +>>> annealed = annealing_process.getSystem(block=True) + +The annealing process is fully customisable, and any number of λ-specific values +can be annealed. 
See :data:`BioSimSpace.Protocol.ATMAnnealing` for the +full list of annealing options. + +The final stage of the ATM minimisation and equilibration protocol is a +post-annealing equilibration run, this time with the ATMForce present at λ=0.5: + +>>> post_anneal_equilibration = BSS.Protocol.ATMEquilibration( +... data=atm_data, +... core_alignment=True, +... restraint=ca, +... com_distance_restraint=True, +... use_atm_force=True, +... lambda1 = 0.5, +... lambda2 = 0.5, +... runtime="100ps" +... ) +>>> post_anneal_equilibration_process = BSS.Process.OpenMM( +... annealed, +... post_anneal_equilibration, +... platform="CUDA" +... ) +>>> post_anneal_equilibration_process.start() +>>> post_anneal_equilibration_process.wait() +>>> min_eq_final = post_anneal_equilibration_process.getSystem(block=True) + +.. note :: + A frequent source of instability in ATM production runs is an overlap between the + bound ligand and the protein after a swap in direction. If this is encountered + the first step taken should be to increase the runtime of the post-annealing equilibration. + This gives the system time to adjust to the presence of the new ligand, without the + reduced stability associated with a swap in direction. + +----------------------- +Production and Analysis +----------------------- + +The system is now ready for production. The key decision to be made before +beginning is the number of lambda windows, set using the ``num_lambda`` +argument. If this value is not set, a default of 22 will be set by BioSimSpace. + +.. note :: + Keep in mind that, due to the nature of the alchemical transfer method, a single + production run contains both the forward and reverse direction of both the free + and bound legs, and therefore a larger than usual number of lambda windows is + required for a well sampled result. 
+ +In addition to setting the number of lambdas, any or all of the λ-specific +values can be manually set, with the only condition being that the lists +provided are all of the same length, specifically they must have length equal to +``num_lambda``. See :data:`BioSimSpace.Protocol.ATMProduction` for a full list +of options. + +In the case of this TYK2 perturbation, the values of ``alpha`` and +``uh`` will need to be set manually, as the defaults are not suitable. + +>>> alpha = 22 * [0.1] +>>> uh = 22 * [110.0] +>>> output_directory = "tyk2_atm" +>>> production_atm = BSS.Protocol.ATMProduction( +... data=atm_data, +... core_alignment=True, +... restraint=ca, +... com_distance_restraint=True, +... runtime = "1ns", +... num_lambda=22, +... alpha=alpha, +... uh=uh, +... ) +>>> production_process = BSS.FreeEnergy.ATM( +... system=min_eq_final, +... protocol=production_atm, +... work_dir=output_directory, +... platform="CUDA", +... setup_only=True +... ) + +The ``setup_only`` flag is set to ``True`` here; this means that all input files +will be created, but nothing will be run. It is recommended to run production +protocols on HPC resources where multiple GPUs are available, as the calculations +can be very computationally expensive. + +Running the generated inputs is as simple as running the ``OpenMM.py`` script +contained in each of the labelled ``lambda`` folders of the output directory. + +Once production is complete, the results can be analysed using the built-in +BioSimSpace UWHAM analysis tool. + +>>> BSS.FreeEnergy.ATM.analyse(output_directory) + +This will give the ΔΔG value for the perturbation, as well as the error (both in +kcal/mol). + +That concludes the tutorial on setting up and running an ATM RBFE calculation! +For further information please visit the :data:`API documentation +`, and for further information on the alchemical +transfer method, see the `Gallicchio lab website +`__. 
diff --git a/doc/source/tutorials/images/alignment_visualisation.png b/doc/source/tutorials/images/alignment_visualisation.png new file mode 100644 index 000000000..9f931862d Binary files /dev/null and b/doc/source/tutorials/images/alignment_visualisation.png differ diff --git a/doc/source/tutorials/images/pfep_thermodynamic_cycle.png b/doc/source/tutorials/images/pfep_thermodynamic_cycle.png index 7a5191da5..75630f96f 100644 Binary files a/doc/source/tutorials/images/pfep_thermodynamic_cycle.png and b/doc/source/tutorials/images/pfep_thermodynamic_cycle.png differ diff --git a/doc/source/tutorials/images/tyk2_prepared.png b/doc/source/tutorials/images/tyk2_prepared.png new file mode 100644 index 000000000..df34cfce5 Binary files /dev/null and b/doc/source/tutorials/images/tyk2_prepared.png differ diff --git a/doc/source/tutorials/index.rst b/doc/source/tutorials/index.rst index 612ad44f6..52c92af05 100644 --- a/doc/source/tutorials/index.rst +++ b/doc/source/tutorials/index.rst @@ -24,3 +24,4 @@ please :doc:`ask for support. 
<../support>` hydration_freenrg metadynamics protein_mutations + alchemical_transfer diff --git a/doc/source/tutorials/protein_mutations.rst b/doc/source/tutorials/protein_mutations.rst index b75202552..28af5b8c3 100644 --- a/doc/source/tutorials/protein_mutations.rst +++ b/doc/source/tutorials/protein_mutations.rst @@ -1,6 +1,10 @@ +============================ Alchemical Protein Mutations ============================ +Introduction +============ + In this tutorial you will learn how to use BioSimSpace’s mapping functionality to set up alchemical calculations in order to compute the change in the binding affinity of a ligand as a result of a protein diff --git a/python/BioSimSpace/FreeEnergy/__init__.py b/python/BioSimSpace/FreeEnergy/__init__.py index 7021e19d8..e1e44c7af 100644 --- a/python/BioSimSpace/FreeEnergy/__init__.py +++ b/python/BioSimSpace/FreeEnergy/__init__.py @@ -29,6 +29,8 @@ :toctree: generated/ Relative + ATMSetup + ATM Functions ========= @@ -42,3 +44,4 @@ from ._relative import * from ._utils import * +from ._atm import * diff --git a/python/BioSimSpace/FreeEnergy/_atm.py b/python/BioSimSpace/FreeEnergy/_atm.py new file mode 100644 index 000000000..4d1fe516d --- /dev/null +++ b/python/BioSimSpace/FreeEnergy/_atm.py @@ -0,0 +1,1655 @@ +###################################################################### +# BioSimSpace: Making biomolecular simulation a breeze! +# +# Copyright: 2017-2024 +# +# Authors: Lester Hedges +# Matthew Burman +# +# BioSimSpace is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# BioSimSpace is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with BioSimSpace. If not, see . +###################################################################### + +# Functionality for creating and viewing systems for Atomic transfer. + +__all__ = ["ATMSetup", "ATM"] + +import copy as _copy +import json as _json +import os as _os +import pathlib as _pathlib +import shutil as _shutil +import warnings as _warnings +import zipfile as _zipfile + +from sire.legacy import IO as _SireIO + +from .._SireWrappers import Molecule as _Molecule +from .._SireWrappers import System as _System +from .. import _Utils +from ..Types import Length as _Length +from ..Types import Vector as _Vector +from ..Types import Coordinate as _Coordinate +from ..Align import matchAtoms as _matchAtoms +from ..Align import rmsdAlign as _rmsdAlign +from ..Notebook import View as _View +from .. import _isVerbose +from .. import _is_notebook +from ..Process import OpenMM as _OpenMM +from ..Process import ProcessRunner as _ProcessRunner + +if _is_notebook: + from IPython.display import FileLink as _FileLink + + +class ATMSetup: + """ + A class for setting up a system for ATM simulations. + """ + + def __init__( + self, + system=None, + receptor=None, + ligand_bound=None, + ligand_free=None, + protein_index=0, + ligand_bound_index=1, + ligand_free_index=2, + ): + """Constructor for the ATM class. + + Parameters + ---------- + + system : :class:`System ` + A pre-prepared ATM system containing protein and ligands placed + in their correct positions. If provided takes precedence over + protein, ligand_bound and ligand_free. + + receptor : :class:`Molecule ` + A receptor molecule. Will be used along with ligand_bound and + ligand_free to create a system. + + ligand_bound : :class:`Molecule ` + The bound ligand. Will be used along with protein and ligand_free + to create a system. + + ligand_free : :class:`Molecule ` + The free ligand. 
Will be used along with protein and ligand_bound + to create a system. + + protein_index : int, [int] + If passing a pre-prepared system, the index (or indices) of the + protein molecule in the system (Default 0). + + ligand_bound_index : int + If passing a pre-prepared system, the index of the bound ligand + molecule in the system (Default 1). + + ligand_free_index : int + If passing a pre-prepared system, the index of the free ligand + molecule in the system (Default 2). + """ + # make sure that either system or protein, ligand_bound and ligand_free are given + if system is None and not all( + x is not None for x in [receptor, ligand_bound, ligand_free] + ): + raise ValueError( + "Either a pre-prepared system or protein, bound ligand and free ligand must be given." + ) + # check that the system is a BioSimSpace system + # or the other inputs are BioSimSpace molecules + if system is not None and not isinstance(system, _System): + raise ValueError("The system must be a BioSimSpace System object.") + elif not all( + isinstance(x, _Molecule) + for x in [receptor, ligand_bound, ligand_free] + if x is not None + ): + raise ValueError( + "The protein, bound ligand and free ligand must be BioSimSpace Molecule objects." + ) + self._is_prepared = False + self._setSystem(system) + if not self._is_prepared: + self._setProtein(receptor) + self._setLigandBound(ligand_bound) + self._setLigandFree(ligand_free) + else: + self._setProteinIndex(protein_index) + self._setLigandBoundIndex(ligand_bound_index) + self._setLigandFreeIndex(ligand_free_index) + + def _setSystem(self, system, is_prepared=True): + """ + Set the system for the ATM simulation. + + Parameters + ---------- + + system : BioSimSpace._SireWrappers.System + The system for the ATM simulation. + """ + if system is not None: + if not isinstance(system, _System): + raise ValueError( + f"The system must be a BioSimSpace System object. It is currently {type(system)}." 
+ ) + elif len(system.getMolecules()) < 3: + raise ValueError( + "The system must contain at least three molecules (a protein and two ligands)." + ) + else: + self._system = system + self._is_prepared = is_prepared + else: + self._system = None + self._is_prepared = False + + def _getSystem(self): + """Get the system for the ATM simulation. + + Returns + ------- + BioSimSpace._SireWrappers.System + The system for the ATM simulation. + """ + return self._system + + def _setProtein(self, protein): + """Set the protein for the ATM simulation. + + Parameters + ---------- + + protein : BioSimSpace._SireWrappers.Molecule + The protein for the ATM simulation. + """ + if protein is not None: + if not isinstance(protein, _Molecule): + raise ValueError("The protein must be a BioSimSpace Molecule object.") + else: + self._protein = protein + else: + self._protein = None + + def _getProtein(self): + """Get the protein for the ATM simulation. + + Returns + ------- + BioSimSpace._SireWrappers.Molecule + The protein for the ATM simulation. + """ + return self._protein + + def _setLigandBound(self, ligand_bound): + """Set the bound ligand for the ATM simulation. + + Parameters + ---------- + + ligand_bound : BioSimSpace._SireWrappers.Molecule + The bound ligand for the ATM simulation. + """ + if ligand_bound is not None: + if not isinstance(ligand_bound, _Molecule): + raise ValueError( + "The bound ligand must be a BioSimSpace Molecule object." + ) + else: + self._ligand_bound = ligand_bound + else: + self._ligand_bound = None + + def _getLigandBound(self): + """Get the bound ligand for the ATM simulation. + + Returns + ------- + BioSimSpace._SireWrappers.Molecule + The bound ligand for the ATM simulation. + """ + return self._ligand_bound + + def _setLigandFree(self, ligand_free): + """Set the free ligand for the ATM simulation. + + Parameters + ---------- + + ligand_free : BioSimSpace._SireWrappers.Molecule + The free ligand for the ATM simulation. 
+ """ + if ligand_free is not None: + if not isinstance(ligand_free, _Molecule): + raise ValueError( + "The free ligand must be a BioSimSpace Molecule object." + ) + else: + self._ligand_free = ligand_free + else: + self._ligand_free = None + + def _getLigandFree(self): + """Get the free ligand for the ATM simulation. + + Returns + ------- + BioSimSpace._SireWrappers.Molecule + The free ligand for the ATM simulation. + """ + return self._ligand_free + + def _setDisplacement(self, displacement): + """Set the displacement of the free ligand along the normal vector.""" + if isinstance(displacement, str): + try: + self._displacement = _Length(displacement) + except Exception as e: + raise ValueError( + f"Could not convert {displacement} to a BSS length, due to the following error: {e}" + ) + elif isinstance(displacement, _Length): + self._displacement = displacement + elif isinstance(displacement, list): + if len(displacement) != 3: + raise ValueError("displacement must have length 3") + if all(isinstance(x, (float, int)) for x in displacement): + self._displacement = _Vector(*displacement) + elif all(isinstance(x, _Length) for x in displacement): + self._displacement = _Vector([x.value() for x in displacement]) + else: + raise TypeError("displacement must be a list of floats or BSS lengths") + elif isinstance(displacement, _Vector): + self._displacement = displacement + else: + raise TypeError( + f"displacement must be a string, BSS length or list. It is currently {type(displacement)}." + ) + if self._is_prepared: + if not isinstance(self._displacement, _Vector): + raise ValueError( + "Displacement must be a vector or list if a pre-prepared system is given" + ) + + def _getDisplacement(self): + """Get the displacement of the free ligand along the normal vector. + + Returns + ------- + BioSimSpace.Types.Length + The displacement of the free ligand along the normal vector. 
+ """ + return self._displacement + + def _setLigandBoundRigidCore(self, ligand_bound_rigid_core): + """Set the indices for the rigid core atoms of ligand 1. + + Parameters + ---------- + + ligand_bound_rigid_core : BioSimSpace._SireWrappers.Molecule + The rigid core of the bound ligand for the ATM simulation. + """ + if ligand_bound_rigid_core is None: + self._ligand_bound_rigid_core = None + else: + if not isinstance(ligand_bound_rigid_core, list): + raise TypeError("ligand_bound_rigid_core must be a list") + if len(ligand_bound_rigid_core) != 3: + raise ValueError("ligand_bound_rigid_core must have length 3") + # make sure all indices are ints + if not all(isinstance(x, int) for x in ligand_bound_rigid_core): + raise TypeError("ligand_bound_rigid_core must contain only integers") + if any(x >= self._ligand_bound_atomcount for x in ligand_bound_rigid_core): + raise ValueError( + "ligand_bound_rigid_core contains an index that is greater than the number of atoms in the ligand" + ) + self._ligand_bound_rigid_core = ligand_bound_rigid_core + + def _getLigandBoundRigidCore(self): + """Get the indices for the rigid core atoms of ligand 1. + + Returns + ------- + list + The indices for the rigid core atoms of ligand 1. + """ + return self._ligand_bound_rigid_core + + def _setLigandFreeRigidCore(self, ligand_free_rigid_core): + """Set the indices for the rigid core atoms of ligand 2. + + Parameters + ---------- + + ligand_free_rigid_core : BioSimSpace._SireWrappers.Molecule + The rigid core of the free ligand for the ATM simulation. 
+ """ + if ligand_free_rigid_core is None: + self._ligand_free_rigid_core = None + else: + if not isinstance(ligand_free_rigid_core, list): + raise TypeError("ligand_free_rigid_core must be a list") + if len(ligand_free_rigid_core) != 3: + raise ValueError("ligand_free_rigid_core must have length 3") + # make sure all indices are ints + if not all(isinstance(x, int) for x in ligand_free_rigid_core): + raise TypeError("ligand_free_rigid_core must contain only integers") + if any(x >= self._ligand_free_atomcount for x in ligand_free_rigid_core): + raise ValueError( + "ligand_free_rigid_core contains an index that is greater than the number of atoms in the ligand" + ) + self._ligand_free_rigid_core = ligand_free_rigid_core + + def _getLigandFreeRigidCore(self): + """Get the indices for the rigid core atoms of ligand 2. + + Returns + ------- + list + The indices for the rigid core atoms of ligand 2. + """ + return self._ligand_free_rigid_core + + def _setProteinIndex(self, protein_index): + """ + Set the index of the protein in the system + + Parameters + ---------- + + protein_index : list + The index or indices of the protein in the system. + """ + if isinstance(protein_index, list): + # check that all elements are ints + if not all(isinstance(x, int) for x in protein_index): + raise TypeError("protein_index must be a list of ints or a single int") + for p in protein_index: + if p < 0: + raise ValueError("protein_index must be a positive integer") + if self._system[p].isWater(): + _warnings.warn( + f"The molecule at index {p} is a water molecule, check your protein_index list." + ) + self.protein_index = protein_index + elif isinstance(protein_index, int): + self.protein_index = [protein_index] + else: + raise TypeError("protein_index must be an int or a list of ints") + + def _getProteinIndex(self): + """Get the index of the protein molecule in the system. + + Returns + ------- + int + The index of the protein molecule in the system. 
+ """ + return self.protein_index + + def _setLigandBoundIndex(self, ligand_bound_index): + """Set the index of the bound ligand molecule in the system. + + Parameters + ---------- + + ligand_bound_index : int + The index of the bound ligand molecule in the system. + """ + if not isinstance(ligand_bound_index, int): + raise ValueError("ligand_bound_index must be an integer.") + else: + if ligand_bound_index < 0: + raise ValueError("ligand_bound_index must be a positive integer") + if self._system[ligand_bound_index].isWater(): + _warnings.warn( + f"The molecule at index {ligand_bound_index} is a water molecule, check your ligand_bound_index." + ) + self._ligand_bound_index = ligand_bound_index + + def _getLigandBoundIndex(self): + """Get the index of the bound ligand molecule in the system. + + Returns + ------- + int + The index of the bound ligand molecule in the system. + """ + return self._ligand_bound_index + + def _setLigandFreeIndex(self, ligand_free_index): + """Set the index of the free ligand molecule in the system. + + Parameters + ---------- + + ligand_free_index : int + The index of the free ligand molecule in the system. + """ + if not isinstance(ligand_free_index, int): + raise ValueError("ligand_free_index must be an integer.") + else: + if ligand_free_index < 0: + raise ValueError("ligand_free_index must be a positive integer") + if self._system[ligand_free_index].isWater(): + _warnings.warn( + f"The molecule at index {ligand_free_index} is a water molecule, check your ligand_free_index." + ) + self._ligand_free_index = ligand_free_index + + def _getLigandFreeIndex(self): + """Get the index of the free ligand molecule in the system. + + Returns + ------- + int + The index of the free ligand molecule in the system. 
+ """ + return self._ligand_free_index + + def prepare( + self, + ligand_bound_rigid_core, + ligand_free_rigid_core, + displacement="20A", + protein_com_atoms=None, + ligand_bound_com_atoms=None, + ligand_free_com_atoms=None, + ): + """ + Prepare the system for an ATM simulation. + + Parameters + ---------- + + ligand_bound_rigid_core : [int] + A list of three atom indices that define the rigid core of the bound ligand. + Indices are set relative to the ligand, not the system and are 0-indexed. + + ligand_free_rigid_core : [int] + A list of three atom indices that define the rigid core of the free ligand. + Indices are set relative to the ligand, not the system and are 0-indexed. + + displacement : float, string, [float, float, float] + The diplacement between the bound and free ligands. + If a float or string is given, BioSimSpace will attempt to find the ideal + vector along which to displace the ligand by the given magnitude. If a list + is given, the vector will be used directly. + + protein_com_atoms : [int] + A list of atom indices that define the center of mass of the protein. + If None, the center of mass of the protein will be found automatically. + + ligand_bound_com_atoms : [int] + A list of atom indices that define the center of mass of the bound ligand. + If None, the center of mass of the bound ligand will be found automatically. + + ligand_free_com_atoms : [int] + A list of atom indices that define the center of mass of the free ligand. + If None, the center of mass of the free ligand will be found automatically. + + Returns + ------- + + system : :class:`System ` + The prepared system, including protein and ligands in their correct positions. + + data : dict + A dictionary containing the data needed for the ATM simulation. This is + also encoded in the system for consistency, but is returned so that the + user can easily query and validate the data. 
+ """ + if self._is_prepared: + self._systemInfo() + self._setLigandBoundRigidCore(ligand_bound_rigid_core) + self._setLigandFreeRigidCore(ligand_free_rigid_core) + self._setDisplacement(displacement) + self._setProtComAtoms(protein_com_atoms) + self._setLig1ComAtoms(ligand_bound_com_atoms) + self._setLig2ComAtoms(ligand_free_com_atoms) + + self._findAtomIndices() + self._makeData() + serialisable_disp = [ + self._displacement.x(), + self._displacement.y(), + self._displacement.z(), + ] + temp_data = self.data.copy() + temp_data["displacement"] = serialisable_disp + self._system._sire_object.setProperty("atom_data", _json.dumps(temp_data)) + return self._system, self.data + + else: + # A bit clunky, but setDisplacement needs to be called twice - before and after _makeSystemFromThree + # the final value will be set after the system is made, but the initial value is needed to make the system + self._setDisplacement(displacement) + system, prot_ind, lig1_ind, lig2_ind, dis_vec = self._makeSystemFromThree( + self._protein, self._ligand_bound, self._ligand_free, self._displacement + ) + self._setSystem(system, is_prepared=False) + self._setDisplacement(dis_vec) + self._setProteinIndex(prot_ind) + self._setLigandBoundIndex(lig1_ind) + self._setLigandFreeIndex(lig2_ind) + self._systemInfo() + self._setLigandBoundRigidCore(ligand_bound_rigid_core) + self._setLigandFreeRigidCore(ligand_free_rigid_core) + self._setProtComAtoms(protein_com_atoms) + self._setLig1ComAtoms(ligand_bound_com_atoms) + self._setLig2ComAtoms(ligand_free_com_atoms) + self._findAtomIndices() + self._makeData() + serialisable_disp = [ + self._displacement.x(), + self._displacement.y(), + self._displacement.z(), + ] + temp_data = self.data.copy() + temp_data["displacement"] = serialisable_disp + # encode data in system for consistency + self._system._sire_object.setProperty("atom_data", _json.dumps(temp_data)) + return self._system, self.data + + @staticmethod + def _makeSystemFromThree(protein, 
ligand_bound, ligand_free, displacement): + """Create a system for ATM simulations. + + Parameters + ---------- + + protein : BioSimSpace._SireWrappers.Molecule + The protein for the ATM simulation. + + ligand_bound : BioSimSpace._SireWrappers.Molecule + The bound ligand for the ATM simulation. + + ligand_free : BioSimSpace._SireWrappers.Molecule + The free ligand for the ATM simulation. + + displacement : BioSimSpace.Types.Length + The displacement of the ligand along the normal vector. + + Returns + ------- + + BioSimSpace._SireWrappers.System + The system for the ATM simulation. + """ + + def _findTranslationVector(system, displacement, protein, ligand): + + from sire.legacy.Maths import Vector + + if not isinstance(system, _System): + raise TypeError("system must be a BioSimSpace system") + if not isinstance(protein, (_Molecule, type(None))): + raise TypeError("protein must be a BioSimSpace molecule") + if not isinstance(ligand, (_Molecule, type(None))): + raise TypeError("ligand must be a BioSimSpace molecule") + + # Assume that binding sire is the center of mass of the ligand + binding = _Coordinate(*ligand._getCenterOfMass()) + + # Create grid around the binding site + # This will act as the search region + grid_length = _Length(20.0, "angstroms") + + num_edges = 5 + search_radius = (grid_length / num_edges) / 2 + grid_min = binding - 0.5 * grid_length + grid_max = binding + 0.5 * grid_length + + non_protein_coords = Vector() + # Count grid squares that contain no protein atoms + num_non_prot = 0 + + import numpy as np + + # Loop over the grid + for x in np.linspace(grid_min.x().value(), grid_max.x().value(), num_edges): + for y in np.linspace( + grid_min.y().value(), grid_max.y().value(), num_edges + ): + for z in np.linspace( + grid_min.z().value(), grid_max.z().value(), num_edges + ): + search = ( + f"atoms within {search_radius.value()} of ({x}, {y}, {z})" + ) + + try: + protein.search(search) + except: + non_protein_coords += Vector(x, y, z) + 
num_non_prot += 1 + + non_protein_coords /= num_non_prot + non_protein_coords = _Coordinate._from_sire_vector(non_protein_coords) + + # Now search out alpha carbons in system + x = binding.x().angstroms().value() + y = binding.y().angstroms().value() + z = binding.z().angstroms().value() + string = f"(atoms within 10 of {x},{y},{z}) and atomname CA" + + try: + search = system.search(string) + except: + _warnings.warn( + "No alpha carbons found in system, falling back on any carbon atoms." + ) + try: + string = f"(atoms within 10 of {x},{y},{z}) and element C" + search = system.search(string) + except: + raise ValueError("No carbon atoms found in system") + + com = _Coordinate(_Length(0, "A"), _Length(0, "A"), _Length(0, "A")) + atoms1 = [] + for atom in search: + com += atom.coordinates() + atoms1.append(system.getIndex(atom)) + com /= search.nResults() + + initial_normal_vector = (non_protein_coords - com).toVector().normalise() + + out_of_protein = displacement.value() * initial_normal_vector + return out_of_protein + + mapping = _matchAtoms(ligand_free, ligand_bound) + ligand_free_aligned = _rmsdAlign(ligand_free, ligand_bound, mapping) + prot_lig1 = (protein + ligand_bound).toSystem() + + if isinstance(displacement, _Vector): + ligand_free_aligned.translate( + [displacement.x(), displacement.y(), displacement.z()] + ) + vec = displacement + else: + vec = _findTranslationVector(prot_lig1, displacement, protein, ligand_bound) + ligand_free_aligned.translate([vec.x(), vec.y(), vec.z()]) + + sys = (protein + ligand_bound + ligand_free_aligned).toSystem() + prot_ind = sys.getIndex(protein) + lig1_ind = sys.getIndex(ligand_bound) + lig2_ind = sys.getIndex(ligand_free_aligned) + return sys, prot_ind, lig1_ind, lig2_ind, vec + + def _systemInfo(self): + """ + If the user gives a pre-prepared ATM system, extract the needed information. 
+ """ + for p in self.protein_index: + if self._system[p].isWater(): + _warnings.warn( + f"The molecule at index {self.protein_index} appears to be a water molecule." + " This should be a protein." + ) + if self._system[self._ligand_bound_index].isWater(): + _warnings.warn( + f"The molecule at index {self._ligand_bound_index} appears to be a water molecule." + " This should be the bound ligand." + ) + if self._system[self._ligand_free_index].isWater(): + _warnings.warn( + f"The molecule at index {self._ligand_free_index} appears to be a water molecule." + " This should be the free ligand." + ) + self._protein_atomcount = sum( + self._system[i].nAtoms() for i in self.protein_index + ) + self._ligand_bound_atomcount = self._system[self._ligand_bound_index].nAtoms() + self._ligand_free_atomcount = self._system[self._ligand_free_index].nAtoms() + + def _findAtomIndices(self): + """ + Find the indices of the protein and ligand atoms in the system + + Returns + ------- + dict + A dictionary containing the indices of the protein and ligand atoms in the system + """ + protein_atom_start = self._system[self.protein_index[0]].getAtoms()[0] + protein_atom_end = self._system[self.protein_index[-1]].getAtoms()[-1] + self._first_protein_atom_index = self._system.getIndex(protein_atom_start) + self._last_protein_atom_index = self._system.getIndex(protein_atom_end) + + ligand_bound_atom_start = self._system[self._ligand_bound_index].getAtoms()[0] + ligand_bound_atom_end = self._system[self._ligand_bound_index].getAtoms()[-1] + self._first_ligand_bound_atom_index = self._system.getIndex( + ligand_bound_atom_start + ) + self._last_ligand_bound_atom_index = self._system.getIndex( + ligand_bound_atom_end + ) + + ligand_free_atom_start = self._system[self._ligand_free_index].getAtoms()[0] + ligand_free_atom_end = self._system[self._ligand_free_index].getAtoms()[-1] + self._first_ligand_free_atom_index = self._system.getIndex( + ligand_free_atom_start + ) + 
self._last_ligand_free_atom_index = self._system.getIndex(ligand_free_atom_end) + + def _getProtComAtoms(self): + """ + Get the atoms that define the center of mass of the protein as a list of ints + + Returns + ------- + list + A list of atom indices that define the center of mass of the protein. + """ + return self._mol1_com_atoms + + def _setProtComAtoms(self, prot_com_atoms): + """ + Set the atoms that define the center of mass of the protein + If a list is given, simply set them according to the list. + If None, find them based on the center of mass of the protein. + """ + if prot_com_atoms is not None: + # Make sure its a list of ints + if not isinstance(prot_com_atoms, list): + raise TypeError("mol1_com_atoms must be a list") + if not all(isinstance(x, int) for x in prot_com_atoms): + raise TypeError("mol1_com_atoms must be a list of ints") + self._mol1_com_atoms = prot_com_atoms + else: + # Find com of the protein + if self._is_prepared: + temp_system = self._system._sire_object + protein = temp_system[self.protein_index[0]] + for i in self.protein_index[1:]: + protein += temp_system[i] + com = protein.coordinates() + self._mol1_com_atoms = [ + a.index().value() + for a in protein[f"atoms within 11 angstrom of {com}"] + ] + del temp_system + del protein + else: + protein = self._protein + com = protein._sire_object.coordinates() + self._mol1_com_atoms = [ + a.index().value() + for a in protein._sire_object[f"atoms within 11 angstrom of {com}"] + ] + + def _getLig1ComAtoms(self): + """ + Get the atoms that define the center of mass of the bound ligand as a list of ints + + Returns + ------- + list + A list of atom indices that define the center of mass of the bound ligand. + """ + return self._lig1_com_atoms + + def _setLig1ComAtoms(self, lig1_com_atoms): + """ + Set the atoms that define the center of mass of the bound ligand + If a list is given, simply set them according to the list. + If None, find them based on the center of mass of the bound ligand. 
+ In most cases this will be all atoms within the ligand + """ + if lig1_com_atoms is not None: + # Make sure its a list of ints + if not isinstance(lig1_com_atoms, list): + raise TypeError("lig1_com_atoms must be a list") + if not all(isinstance(x, int) for x in lig1_com_atoms): + raise TypeError("lig1_com_atoms must be a list of ints") + self._lig1_com_atoms = lig1_com_atoms + else: + # Find com of the ligand + if self._is_prepared: + ligand_bound = self._system[self._ligand_bound_index] + else: + ligand_bound = self._ligand_bound + com = ligand_bound._sire_object.coordinates() + self._lig1_com_atoms = [ + a.index().value() + for a in ligand_bound._sire_object[f"atoms within 11 angstrom of {com}"] + ] + + def _getLig2ComAtoms(self): + """ + Get the atoms that define the center of mass of the free ligand as a list of ints + + Returns + ------- + list + A list of atom indices that define the center of mass of the free ligand. + """ + return self._lig2_com_atoms + + def _setLig2ComAtoms(self, lig2_com_atoms): + """ + Set the atoms that define the center of mass of the free ligand + If a list is given, simply set them according to the list. + If None, find them based on the center of mass of the free ligand. 
+ In most cases this will be all atoms within the ligand + """ + if lig2_com_atoms is not None: + # Make sure its a list of ints + if not isinstance(lig2_com_atoms, list): + raise TypeError("lig2_com_atoms must be a list") + if not all(isinstance(x, int) for x in lig2_com_atoms): + raise TypeError("lig2_com_atoms must be a list of ints") + self._lig2_com_atoms = lig2_com_atoms + else: + # Find com of the ligand + if self._is_prepared: + ligand_free = self._system[self._ligand_free_index] + else: + ligand_free = self._ligand_free + com = ligand_free._sire_object.coordinates() + self._lig2_com_atoms = [ + a.index().value() + for a in ligand_free._sire_object[f"atoms within 11 angstrom of {com}"] + ] + + def _makeData(self): + """ + Make the data dictionary for the ATM system + """ + self.data = {} + self.data["displacement"] = self._getDisplacement() + self.data["protein_index"] = self._getProteinIndex() + self.data["ligand_bound_index"] = self._getLigandBoundIndex() + self.data["ligand_free_index"] = self._getLigandFreeIndex() + self.data["ligand_bound_rigid_core"] = self._getLigandBoundRigidCore() + self.data["ligand_free_rigid_core"] = self._getLigandFreeRigidCore() + self.data["mol1_atomcount"] = self._protein_atomcount + self.data["ligand_bound_atomcount"] = self._ligand_bound_atomcount + self.data["ligand_free_atomcount"] = self._ligand_free_atomcount + self.data["first_protein_atom_index"] = self._first_protein_atom_index + self.data["last_protein_atom_index"] = self._last_protein_atom_index + self.data["first_ligand_bound_atom_index"] = self._first_ligand_bound_atom_index + self.data["last_ligand_bound_atom_index"] = self._last_ligand_bound_atom_index + self.data["first_ligand_free_atom_index"] = self._first_ligand_free_atom_index + self.data["last_ligand_free_atom_index"] = self._last_ligand_free_atom_index + self.data["protein_com_atoms"] = self._mol1_com_atoms + self.data["ligand_bound_com_atoms"] = self._lig1_com_atoms + self.data["ligand_free_com_atoms"] 
= self._lig2_com_atoms + + @staticmethod + def viewRigidCores( + system=None, + ligand_bound=None, + ligand_free=None, + ligand_bound_rigid_core=None, + ligand_free_rigid_core=None, + ): + """ + View the rigid cores of the ligands. + Rigid core atoms within the bound ligand are shown in green, those within the free ligand are shown in red. + + Parameters + ---------- + + system : :class:`System ` + The system for the ATM simulation that has been prepared ATM.prepare(). + All other arguments are ignored if this is provided. + + ligand_bound : :class:`Molecule ` + The bound ligand. + + ligand_free : :class:`Molecule ` + The free ligand. + + ligand_bound_rigid_core : list + The indices for the rigid core atoms of the bound ligand. + + ligand_free_rigid_core : list + The indices for the rigid core atoms of the free ligand. + """ + import math as _math + + def move_to_origin(lig): + com = _Coordinate(*lig._getCenterOfMass()) + lig.translate([-com.x().value(), -com.y().value(), -com.z().value()]) + + def euclidean_distance(point1, point2): + return _math.sqrt( + (point1[0] - point2[0]) ** 2 + + (point1[1] - point2[1]) ** 2 + + (point1[2] - point2[2]) ** 2 + ) + + def furthest_points(points): + max_distance = 0 + furthest_pair = None + n = len(points) + + if n < 2: + return None, None, 0 # Not enough points to compare + + for i in range(n): + for j in range(i + 1, n): + distance = euclidean_distance(points[i], points[j]) + if distance > max_distance: + max_distance = distance + furthest_pair = (points[i], points[j]) + + return furthest_pair[0], furthest_pair[1], max_distance + + def vector_from_points(point1, point2): + dx = point2[0] - point1[0] + dy = point2[1] - point1[1] + dz = point2[2] - point1[2] + + magnitude = _math.sqrt(dx**2 + dy**2 + dz**2) + if magnitude == 0: + return (0, 0, 0) + + return (dx / magnitude, dy / magnitude, dz / magnitude) + + def find_arrow_points(center1, center2, diameter1, diameter2): + import numpy as np + + # logic to make sure arrows 
pointing from spheres start on their edge, not centre. + # Convert the points to numpy arrays + p1 = np.array(center1) + p2 = np.array(center2) + + # Calculate the radii from the diameters + radius1 = diameter1 / 2 + radius2 = diameter2 / 2 + + # Calculate the vector between the two centers + v = p2 - p1 + + # Calculate the magnitude (distance between the two centers) + dist = np.linalg.norm(v) + + # Normalize the vector + v_norm = v / dist + + # Calculate the start and end points of the arrow + start_point = p1 + radius1 * v_norm # From the surface of Sphere 1 + end_point = p2 - radius2 * v_norm # To the surface of Sphere 2 + + return start_point, end_point + + # if a system is provided, check that it has the "atom_data" property + if system is not None: + sdata = _json.loads(system._sire_object.property("atom_data").value()) + local_s = system.copy() + ligand_bound = local_s[sdata["ligand_bound_index"]] + move_to_origin(ligand_bound) + ligand_free = local_s[sdata["ligand_free_index"]] + move_to_origin(ligand_free) + ligand_bound_rigid_core = sdata["ligand_bound_rigid_core"] + ligand_free_rigid_core = sdata["ligand_free_rigid_core"] + + # if not system provided, ALL other parameters must be provided + else: + if ligand_bound is None: + raise ValueError("ligand_bound must be provided") + if ligand_free is None: + raise ValueError("ligand_free must be provided") + if ligand_bound_rigid_core is None: + raise ValueError("ligand_bound_rigid_core must be provided") + if ligand_free_rigid_core is None: + raise ValueError("ligand_free_rigid_core must be provided") + + if not isinstance(ligand_bound, _Molecule): + raise TypeError("ligand_bound must be a BioSimSpace molecule") + if not isinstance(ligand_free, _Molecule): + raise TypeError("ligand_free must be a BioSimSpace molecule") + if not isinstance(ligand_bound_rigid_core, list): + raise TypeError("ligand_bound_rigid_core must be a list") + elif not len(ligand_bound_rigid_core) == 3: + raise 
ValueError("ligand_bound_rigid_core must have length 3") + if not isinstance(ligand_free_rigid_core, list): + raise TypeError("ligand_free_rigid_core must be a list") + elif not len(ligand_free_rigid_core) == 3: + raise ValueError("ligand_free_rigid_core must have length 3") + + # copy the ligands + ligand_bound = ligand_bound.copy() + move_to_origin(ligand_bound) + ligand_free = ligand_free.copy() + move_to_origin(ligand_free) + + pre_translation_lig1_core_coords = [] + + for i in ligand_bound_rigid_core: + x = ligand_bound.getAtoms()[i].coordinates().x().value() + y = ligand_bound.getAtoms()[i].coordinates().y().value() + z = ligand_bound.getAtoms()[i].coordinates().z().value() + pre_translation_lig1_core_coords.append((x, y, z)) + + point1, point2, distance = furthest_points(pre_translation_lig1_core_coords) + vector = vector_from_points(point1, point2) + + # need to know the size of ligand_bound + lig1_coords = [] + for i in ligand_bound.getAtoms(): + x = i.coordinates().x().value() + y = i.coordinates().y().value() + z = i.coordinates().z().value() + lig1_coords.append((x, y, z)) + + lig1_point1, lig1_point2, lig1_distance = furthest_points(lig1_coords) + + # Translate ligand_free so they don't overlap + ligand_free.translate( + [ + -1.0 * lig1_distance * 2 * vector[0], + -1.0 * lig1_distance * 2 * vector[1], + -1.0 * lig1_distance * 2 * vector[2], + ] + ) + # Get coords of rigid core atoms + ligand_bound_core_coords = [] + ligand_free_core_coords = [] + for i in ligand_bound_rigid_core: + ligand_bound_core_coords.append(ligand_bound.getAtoms()[i].coordinates()) + for i in ligand_free_rigid_core: + ligand_free_core_coords.append(ligand_free.getAtoms()[i].coordinates()) + + # Create molecule containing both ligands + mol = ligand_bound + ligand_free + + # Create view + view = _ViewAtoM(mol) + + # Create nglview object + ngl = view.system(mol) + + ngl.add_ball_and_stick("all", opacity=0.5) + + # Add spheres to rigid core locations - first the obund ligand with 
red spheres + for coord1, core_atom_1 in zip( + ligand_bound_core_coords, + ligand_bound_rigid_core, + ): + ngl.shape.add_sphere( + [coord1.x().value(), coord1.y().value(), coord1.z().value()], + [0, 1, 0], + 0.45, + ) + ngl.shape.add( + "text", + [coord1.x().value(), coord1.y().value(), coord1.z().value() - 0.9], + [0, 0, 0], + 2.5, + f"{core_atom_1}", + ) + + # now the free ligand with black spheres + for coord1, core_atom_1 in zip( + ligand_free_core_coords, + ligand_free_rigid_core, + ): + ngl.shape.add_sphere( + [coord1.x().value(), coord1.y().value(), coord1.z().value()], + [1, 0, 0], + 0.45, + ) + ngl.shape.add( + "text", + [coord1.x().value(), coord1.y().value(), coord1.z().value() - 0.9], + [0, 0, 0], + 2.5, + f"{core_atom_1}", + ) + + for i in range(2): + c00 = ligand_bound_core_coords[i] + coord00 = [c00.x().value(), c00.y().value(), c00.z().value()] + c01 = ligand_bound_core_coords[i + 1] + coord01 = [c01.x().value(), c01.y().value(), c01.z().value()] + start, end = find_arrow_points(coord00, coord01, 0.9, 0.9) + ngl.shape.add_arrow( + start, + end, + [0, 0, 0], + 0.1, + ) + c10 = ligand_free_core_coords[i] + coord10 = [c10.x().value(), c10.y().value(), c10.z().value()] + c11 = ligand_free_core_coords[i + 1] + coord11 = [c11.x().value(), c11.y().value(), c11.z().value()] + start, end = find_arrow_points(coord10, coord11, 0.9, 0.9) + ngl.shape.add_arrow( + start, + end, + [0, 0, 0], + 0.1, + ) + + if system is not None: + del local_s + return ngl + + +class ATM: + """ + A class for setting up, running, and analysis RBFE calculations using the + Alchemical Transfer Method. + """ + + def __init__( + self, + system, + protocol, + platform="CPU", + work_dir=None, + setup_only=False, + property_map={}, + ): + """ + Constructor. + + Parameters + ---------- + + system : BioSimSpace._SireWrappers.System + A prepared ATM system containing a protein and two ligands, one bound and one free. + Assumed to already be equilibrated. 
+ + protocol : BioSimSpace.Protocol.ATM + The ATM protocol to use for the simulation. + + platform : str + The platform for the simulation: “CPU”, “CUDA”, or “OPENCL”. + For CUDA use the CUDA_VISIBLE_DEVICES environment variable to set the GPUs on which to run, + e.g. to run on two GPUs indexed 0 and 1 use: CUDA_VISIBLE_DEVICES=0,1. + For OPENCL, instead use OPENCL_VISIBLE_DEVICES. + + work_dir : str + The working directory for the simulation. + + setup_only : bool + Whether to only support simulation setup. If True, then no + simulation processes objects will be created, only the directory + hierarchy and input files to run a simulation externally. This + can be useful when you don't intend to use BioSimSpace to run + the simulation. Note that a 'work_dir' must also be specified. + + property_map : dict + A dictionary that maps system "properties" to their user defined + values. This allows the user to refer to properties with their + own naming scheme, e.g. { "charge" : "my-charge" } + + """ + + self._system = system.copy() + + # Validate the protocol. + if protocol is not None: + from ..Protocol._atm import ATMProduction as _Production + + if not isinstance(protocol, _Production): + raise TypeError( + "'protocol' must be of type 'BioSimSpace.Protocol.ATMProduction'" + ) + else: + self._protocol = protocol + else: + # No default protocol due to the need for well-defined rigid cores + raise ValueError("A protocol must be specified") + + # Check the platform. + if not isinstance(platform, str): + raise TypeError("'platform' must be of type 'str'.") + else: + self._platform = platform + + if not isinstance(setup_only, bool): + raise TypeError("'setup_only' must be of type 'bool'.") + else: + self._setup_only = setup_only + + if work_dir is None and setup_only: + raise ValueError( + "A 'work_dir' must be specified when 'setup_only' is True!" + ) + + # Create the working directory. + self._work_dir = _Utils.WorkDir(work_dir) + + # Check that the map is valid. 
+ if not isinstance(property_map, dict): + raise TypeError("'property_map' must be of type 'dict'") + self._property_map = property_map + + self._inititalise_runner(system=self._system) + + def run(self, serial=True): + """ + Run the simulations. + + serial : bool + Whether to run the individual processes for the lambda windows + """ + if not isinstance(serial, bool): + raise TypeError("'serial' must be of type 'bool'.") + + if self._setup_only: + _warnings.warn("No processes exist! Object created in 'setup_only' mode.") + else: + self._runner.startAll(serial=serial) + + def wait(self): + """Wait for the simulation to finish.""" + if self._setup_only: + _warnings.warn("No processes exist! Object created in 'setup_only' mode.") + else: + self._runner.wait() + + def kill(self, index): + """ + Kill a process for a specific lambda window. + + Parameters + ---------- + + index : int + The index of the lambda window. + """ + self._runner.kill(index) + + def killAll(self): + """Kill any running processes for all lambda windows.""" + + self._runner.killAll() + + def workDir(self): + """ + Return the working directory. + + Returns + ------- + + work_dir : str + The path of the working directory. + """ + return str(self._work_dir) + + def getData(self, name="data", file_link=False, work_dir=None): + """ + Return a link to a zip file containing the data files required for + post-simulation analysis. + + Parameters + ---------- + + name : str + The name of the zip file. + + file_link : bool + Whether to return a FileLink when working in Jupyter. + + work_dir : str + The working directory for the free-energy perturbation + simulation. + + Returns + ------- + + output : str, IPython.display.FileLink + A path, or file link, to an archive of the process input. 
+ """ + + if self._work_dir is None: + raise ValueError("'work_dir' must be set!") + else: + if not isinstance(work_dir, str): + raise TypeError("'work_dir' must be of type 'str'.") + if not _os.path.isdir(work_dir): + raise ValueError("'work_dir' doesn't exist!") + + if not isinstance(name, str): + raise TypeError("'name' must be of type 'str'") + + # Generate the zip file name. + zipname = "%s.zip" % name + + # Get the current working directory. + cwd = _os.getcwd() + + # Change into the working directory. + with _cd(work_dir): + # Specify the path to glob. + glob_path = _pathlib.Path(work_dir) + + # First try SOMD data. + files = glob_path.glob("**/gradients.dat") + + if len(files) == 0: + files = glob_path.glob("**/[!bar]*.xvg") + + if len(files) == 0: + raise ValueError( + f"Couldn't find any analysis files in '{work_dir}'" + ) + + # Write to the zip file. + with _zipfile.Zipfile(_os.join(cwd, zipname), "w") as zip: + for file in files: + zip.write(file) + + # Return a link to the archive. + if _is_notebook: + if file_link: + # Create a FileLink to the archive. + f_link = _FileLink(zipname) + + # Set the download attribute so that JupyterLab doesn't try to open the file. + f_link.html_link_str = ( + f"%s" + ) + + # Return a link to the archive. + return f_link + else: + return zipname + # Return the path to the archive. + else: + return zipname + + def _inititalise_runner(self, system): + """ + Internal helper function to initialise the process runner. + + Parameters + ---------- + + system : :class:`System ` + The molecular system. 
+ """ + + # This protocol will have to be minimal - cannot guess rigid core atoms + if self._protocol is None: + raise RuntimeError("No protocol has been set - cannot run simulations.") + # Initialise list to store the processe + processes = [] + # Get the list of lambda1 values so that the total number of simulations can + # be asserted + lambda_list = self._protocol._get_lambda_values() + # Set index of current simulation to 0 + self._protocol.set_current_index(0) + lam = lambda_list[0] + + first_dir = "%s/lambda_%5.4f" % (self._work_dir, lam) + + # Create the first simulation, which will be copied and used for future simulations. + first_process = _OpenMM( + system=system, + protocol=self._protocol, + platform=self._platform, + work_dir=first_dir, + property_map=self._property_map, + ) + + if self._setup_only: + del first_process + else: + processes.append(first_process) + + # Remove first index as its already been used + lambda_list = lambda_list[1:] + # Enumerate starting at 1 to account for the removal of the first lambda value + for index, lam in enumerate(lambda_list, 1): + # Files are named according to index, rather than lambda value + # This is to avoid confusion arising from the fact that there are multiple lambdas + # and that the values of lambda1 and lambda2 wont necessarily be go from 0 to 1 + # and may contain duplicates + new_dir = "%s/lambda_%5.4f" % (self._work_dir, lam) + # Use absolute path. + if not _os.path.isabs(new_dir): + new_dir = _os.path.abspath(new_dir) + + # Delete any existing directories. + if _os.path.isdir(new_dir): + _shutil.rmtree(new_dir, ignore_errors=True) + + # Copy the first directory to that of the current lambda value. + _shutil.copytree(first_dir, new_dir) + # For speed reasons, additional processes need to be created by copying the first process. 
+ # this is more difficult than usual due to the number of window-dependent variables + new_config = [] + # All variables that need to change + new_lam_1 = self._protocol.getLambda1()[index] + new_lam_2 = self._protocol.getLambda2()[index] + new_alpha = self._protocol.getAlpha()[index].value() + new_uh = self._protocol.getUh()[index].value() + new_w0 = self._protocol.getW0()[index].value() + new_direction = self._protocol.getDirection()[index] + with open(new_dir + "/openmm_script.py", "r") as f: + for line in f: + if line.startswith("lambda1"): + new_config.append(f"lambda1 = {new_lam_1}\n") + elif line.startswith("lambda2"): + new_config.append(f"lambda2 = {new_lam_2}\n") + elif line.startswith("alpha"): + new_config.append( + f"alpha = {new_alpha} * kilocalories_per_mole\n" + ) + elif line.startswith("uh"): + new_config.append(f"uh = {new_uh} * kilocalories_per_mole\n") + elif line.startswith("w0"): + new_config.append(f"w0 = {new_w0} * kilocalories_per_mole\n") + elif line.startswith("direction"): + new_config.append(f"direction = {new_direction}\n") + elif line.startswith("window_index"): + new_config.append(f"window_index = {index}\n") + else: + new_config.append(line) + with open(new_dir + "/openmm_script.py", "w") as f: + for line in new_config: + f.write(line) + + # Create a new process object for the current lambda value and append + # to the list of processes + if not self._setup_only: + process = _copy.copy(first_process) + process._system = first_process._system.copy() + process._protocol = self._protocol + process._work_dir = new_dir + process._stdout_file = new_dir + "/ATM.out" + process._stderr_file = new_dir + "/ATM.err" + process._rst_file = new_dir + "/openmm.rst7" + process._top_file = new_dir + "/openmm.prm7" + process._traj_file = new_dir + "/openmm.dcd" + process._config_file = new_dir + "/openmm_script.py" + process._input_files = [ + process._config_file, + process._rst_file, + process._top_file, + ] + processes.append(process) + + if not 
self._setup_only: + # Initialise process runner. + self._runner = _ProcessRunner(processes) + + @staticmethod + def analyse( + work_dir, + method="UWHAM", + ignore_lower=0, + ignore_upper=None, + inflex_indices=None, + ): + """Analyse the ATM simulation. + + Parameters + ---------- + + work_dir : str + The working directory where the ATM simulation is located. + + method : str + The method to use for the analysis. Currently only UWHAM is supported. + + ignore_lower : int + Ignore the first N samples when analysing. + + inflex_indices : [int] + The indices at which the direction changes. For example, if direction=[1,1,-1,-1], + then inflex_indices=[1,2]. + If None, the inflexion point will be found automatically. + + Returns + ------- + + ddg : :class:`BioSimSpace.Types.Energy` + The free energy difference between the two ligands. + + ddg_err :class:`BioSimSpace.Types.Energy` + The error in the free energy difference. + """ + if not isinstance(ignore_lower, int): + raise TypeError("'ignore_lower' must be an integer.") + if ignore_lower < 0: + raise ValueError("'ignore_lower' must be a positive integer.") + if ignore_upper is not None: + if not isinstance(ignore_upper, int): + raise TypeError("'ignore_upper' must be an integer.") + if ignore_upper < 0: + raise ValueError("'ignore_upper' must be a positive integer.") + if ignore_upper < ignore_lower: + raise ValueError( + "'ignore_upper' must be greater than or equal to 'ignore_lower'." 
+ ) + if inflex_indices is not None: + if not isinstance(inflex_indices, list): + raise TypeError("'inflex_indices' must be a list.") + if not all(isinstance(x, int) for x in inflex_indices): + raise TypeError("'inflex_indices' must be a list of integers.") + if not len(inflex_indices) == 2: + raise ValueError("'inflex_indices' must have length 2.") + if method == "UWHAM": + total_ddg, total_ddg_err = ATM._analyse_UWHAM( + work_dir, ignore_lower, ignore_upper, inflex_indices + ) + return total_ddg, total_ddg_err + if method == "MBAR": + from ._relative import Relative as _Relative + + # temporary version to check that things are working + ddg_forward, ddg_reverse = ATM._analyse_MBAR(work_dir) + ddg_forward = _Relative.difference(ddg_forward) + ddg_reverse = _Relative.difference(ddg_reverse) + return ddg_forward, ddg_reverse + else: + raise ValueError(f"Method {method} is not supported for analysis.") + + @staticmethod + def _analyse_UWHAM(work_dir, ignore_lower, ignore_upper, inflex_indices=None): + """ + Analyse the UWHAM results from the ATM simulation. + """ + from ._ddg import analyse_UWHAM as _UWHAM + + total_ddg, total_ddg_err = _UWHAM( + work_dir, ignore_lower, ignore_upper, inflection_indices=inflex_indices + ) + return total_ddg, total_ddg_err + + @staticmethod + def _analyse_MBAR(work_dir): + """ + Analyse the MBAR results from the ATM simulation. + """ + from ._ddg import analyse_MBAR as _MBAR + + ddg_forward, ddg_reverse = _MBAR(work_dir) + return ddg_forward, ddg_reverse + + @staticmethod + def _analyse_test(work_dir): + """ + Analyse the test results from the ATM simulation. 
class _ViewAtoM(_View):
    """
    Variant of the regular view class whose only purpose is to call
    ``show_file`` with ``default_representation=False``.
    """

    def __init__(self, handle, property_map={}, is_lambda1=False):
        # No extra state: everything is delegated to the parent constructor.
        super().__init__(handle, property_map, is_lambda1)

    def _create_view(self, system=None, view=None, gui=True, **kwargs):
        """
        Write 'system' to a numbered PDB file (or reuse the file of an
        existing 'view' index) and display it with NGLView.

        Exactly one of 'system' and 'view' must be given; a ValueError is
        raised otherwise. An IOError is raised if the PDB file cannot be
        written.
        """
        if system is None and view is None:
            raise ValueError("Both 'system' and 'view' cannot be 'None'.")
        if system is not None and view is not None:
            raise ValueError("One of 'system' or 'view' must be 'None'.")

        # Anything that is not a valid flag falls back to showing the GUI.
        if gui not in (True, False):
            gui = True

        # A new view takes the next free index, otherwise reuse the one given.
        is_new_view = view is None
        index = self._num_views if is_new_view else view

        filename = "%s/view_%04d.pdb" % (self._work_dir, index)

        if is_new_view:
            self._num_views += 1

        # Only a freshly supplied system needs writing to disk.
        if system is not None:
            try:
                _SireIO.PDB2(system, self._property_map).writeToFile(filename)
            except Exception as e:
                msg = "Failed to write system to 'PDB' format."
                if not _isVerbose():
                    raise IOError(msg) from None
                print(msg)
                raise IOError(e) from None

        # NGLView is imported lazily, the first time a view is created.
        import nglview as _nglview

        widget = _nglview.show_file(filename, default_representation=False)

        return widget.display(gui=gui)
def _compute_weights(ln_z, ln_q, factor):
    """
    Return ``q_ij / sum(factor * q_ij)`` with ``q_ij = exp(ln_q - ln_z)``.

    The normalising sum runs over the last axis of ``ln_q``.
    """
    q_ij = _numpy.exp(ln_q - ln_z)
    return q_ij / (factor * q_ij).sum(axis=-1, keepdims=True)


def _compute_kappa_hessian(ln_z, ln_q, factor, n):
    """
    Hessian matching the objective returned by ``_compute_kappa``.

    ``ln_z`` holds the free parameters only; a leading ``0.0`` is inserted
    for the first state, whose column is dropped from the result again.
    """
    ln_z = _numpy.insert(ln_z, 0, 0.0)

    w = (factor * _compute_weights(ln_z, ln_q, factor))[:, 1:]
    return -w.T @ w / n + _numpy.diag(w.sum(axis=0) / n)


def _compute_kappa(ln_z, ln_q, factor, n):
    """
    Return the UWHAM objective value and its gradient.

    ``ln_z`` holds the free parameters only; a leading ``0.0`` is inserted
    for the first state and the gradient is returned for the remaining
    states, i.e. it has the same length as the ``ln_z`` passed in.
    """
    ln_z = _numpy.insert(ln_z, 0, 0.0)

    ln_q_ij_sum = _special.logsumexp(a=ln_q - ln_z, b=factor, axis=1)
    kappa = ln_q_ij_sum.sum() / n + (factor * ln_z).sum()

    w = factor * _compute_weights(ln_z, ln_q, factor)
    grad = -w[:, 1:].sum(axis=0) / n + factor[1:]

    return kappa, grad


def _compute_variance(ln_z, w, factor, n):
    """
    Return the per-state variance estimate, with ``0.0`` for the first state.

    ``ln_z`` is only used for its length (the number of states); ``w`` is
    the weight matrix returned by ``_compute_weights``.
    """
    o = w.T @ w / n

    b = o * factor - _numpy.eye(len(ln_z))
    b = b[1:, 1:]

    b_inv_a = -o + o[0, :]
    b_inv_a = b_inv_a[1:, 1:]

    var_matrix = (b_inv_a @ _numpy.linalg.inv(b.T)) / n
    return _numpy.insert(_numpy.diag(var_matrix), 0, 0.0)


def _bias_fcn(epert, lam1, lam2, alpha, u0, w0):
    """
    This is for the bias ilogistic potential
    (lambda2-lambda1) ln[1+exp(-alpha (u-u0))]/alpha + lambda2 u + w0

    When ``alpha`` is not positive the logarithmic term is omitted and only
    ``lambda2 u + w0`` is returned.
    """
    ebias1 = _numpy.zeros_like(epert)
    if alpha > 0:
        # logaddexp(0, x) is ln(1 + exp(x)) evaluated without forming
        # exp(x), so strongly negative (epert - u0) no longer overflows
        # to inf.
        softplus = _numpy.logaddexp(0.0, -alpha * (epert - u0))
        ebias1 = (lam2 - lam1) * softplus / alpha
    return ebias1 + lam2 * epert + w0


def _npot_fcn(e0, epert, bet, lam1, lam2, alpha, u0, w0):
    """
    Return the negative reduced energy, ``-beta * (U0 + bias)``, where the
    bias is evaluated by ``_bias_fcn``.
    """
    return -bet * (e0 + _bias_fcn(epert, lam1, lam2, alpha, u0, w0))
def _estimate_f_i(ln_q, n_k):
    """Estimates the free energies of a set of *sampled* states.

    Args:
        ln_q: array of negative potentials with ``shape=(n_states,n_samples)``.
        n_k: The number of samples at state ``k``.

    Returns:
        A tuple of the estimated reduced free energies, their estimated
        variance, and the sample weights divided by the total number of
        samples.

    Raises:
        RuntimeError: if the shapes of ``ln_q`` and ``n_k`` disagree, or if
            the minimisation does not report success.
    """
    counts = _numpy.array(n_k)

    # Work on a transposed copy so that rows are samples and columns states.
    log_q = _numpy.array(ln_q).T
    n_samples, n_states = log_q.shape

    if n_states != len(counts):
        raise RuntimeError(
            "The number of states do not match: %d != %d" % (n_states, len(counts))
        )
    if n_samples != counts.sum():
        raise RuntimeError(
            "The number of samples do not match: %d != %d"
            % (n_samples, counts.sum())
        )

    # Only the free parameters are optimised: ln_z_0 is always fixed at 0.0
    start = _numpy.zeros(len(counts) - 1)

    # Express every state relative to the first one.
    log_q -= log_q[:, :1]

    total = counts.sum()
    fractions = counts / total

    objective = _functools.partial(
        _compute_kappa, ln_q=log_q, n=total, factor=fractions
    )
    hessian = _functools.partial(
        _compute_kappa_hessian, ln_q=log_q, n=total, factor=fractions
    )
    result = _optimize.minimize(
        objective,
        start,
        method="trust-ncg",
        jac=True,
        hess=hessian,
    )

    if not result.success:
        raise RuntimeError("The UWHAM minimization failed to converge.")

    f_i = _numpy.insert(-result.x, 0, 0.0)
    ln_z = _numpy.insert(result.x, 0, 0.0)

    weights = _compute_weights(ln_z, log_q, fractions)

    # Sanity check only: a poor normalisation is reported, not raised.
    if not _numpy.allclose(weights.sum(axis=0) / total, 1.0, atol=1e-2):
        w = weights.sum(axis=0) / total
        _warnings.warn(f"The UWHAM weights do not sum to 1.0 ({w})")

    df_i = _compute_variance(ln_z, weights, fractions, total)

    return f_i, df_i, weights / total
def _sort_folders(work_dir):
    """Sorts folder names by lambda value, ensuring they are read correctly.

    Only sub-directories of 'work_dir' whose name starts with "lambda_" and
    ends in a value that parses as a float are considered; anything else is
    ignored.

    Parameters
    ----------

    work_dir : str
        The directory containing the simulation data.

    Returns
    -------

    folders : dict
        A dictionary mapping each lambda value to its folder, ordered by
        increasing lambda value.
    """
    folders = {}
    for folder in _pathlib.Path(work_dir).iterdir():
        if not (folder.is_dir() and folder.name.startswith("lambda_")):
            continue
        try:
            lambda_val = float(folder.name.split("_")[-1])
        except ValueError:
            continue
        folders[lambda_val] = folder
    return dict(sorted(folders.items()))


def _get_inflection_indices(folders):
    """Find the pair of windows between which 'direction' changes.

    Parameters
    ----------

    folders : dict
        Folders keyed by lambda value, as returned by '_sort_folders'. The
        keys must already be in ascending order.

    Returns
    -------

    inflection_indices : (int, int)
        The index of the last window before the first change of direction
        and the index of the window that follows it.

    Raises
    ------

    ValueError
        If the folders are not sorted, or if 'direction' is the same in
        every window so that no inflection point exists.
    """
    # check that the keys are sorted
    keys = list(folders.keys())
    if keys != sorted(keys):
        raise ValueError(f"Folders are not sorted correctly. {keys} != {sorted(keys)}")

    # Only the first row of each file is inspected, so only that row is read.
    directions = [
        _pd.read_csv(folder / "openmm.csv", nrows=1)["direction"].values[0]
        for folder in folders.values()
    ]

    # Report the first pair of neighbouring windows whose direction differs.
    for i, (current, following) in enumerate(zip(directions, directions[1:])):
        if current != following:
            return (i, i + 1)

    raise ValueError(
        "Could not find an inflection point: 'direction' does not change "
        "between any pair of consecutive lambda windows."
    )
def analyse_UWHAM(work_dir, ignore_lower, ignore_upper, inflection_indices=None):
    """
    Analyse the output of BioSimSpace ATM simulations.

    Parameters
    ----------

    work_dir : str
        The directory containing the simulation data.

    ignore_lower : int
        The number of rows to ignore at the start of each file.

    ignore_upper : int
        End of the row range kept from each file. Rows are selected with
        ``df.iloc[ignore_lower:ignore_upper]``, so this is an exclusive row
        index. If None, all rows from 'ignore_lower' onwards are kept.

    inflection_indices : tuple, optional
        The point at which 'direction' changes.
        Should be (last index of direction 1, first index of direction 2).
        If not provided, will be implied from files.

    Returns
    -------

    ddg_total : :class:`BioSimSpace.Types.Energy`
        The free energy.

    ddg_total_error : :class:`BioSimSpace.Types.Energy`
        The error in the free energy.
    """
    # NOTE: This code is not designed to work with repex
    # It always assumes that each window is at the same temperature
    slices = {}
    total_states = 0
    folders = _sort_folders(work_dir)
    if inflection_indices is None:
        inflection_indices = _get_inflection_indices(folders)

    for folder in folders.values():
        df = _pd.read_csv(folder / "openmm.csv")
        # Keep rows [ignore_lower, ignore_upper). An upper bound of None
        # keeps everything after 'ignore_lower'. The explicit copy means the
        # column added below is written to an independent frame, not to a
        # slice of the frame that was read.
        df = df.iloc[ignore_lower:ignore_upper].copy()
        # Inverse temperature.
        # NOTE(review): 0.001986209 is close to the Boltzmann constant in
        # kcal/(mol K) and the result below is tagged kcal/mol, so the
        # energies in openmm.csv are presumably in kcal/mol - confirm.
        df["beta"] = 1 / (0.001986209 * df["temperature"])
        total_states += 1
        # Collect the rows of every file by their 'window' value.
        for window, window_df in df.groupby("window"):
            slices.setdefault(window, []).append(window_df)

    # One combined dataframe per window, ordered by window value.
    dataframes = sorted(
        (_pd.concat(dfs) for dfs in slices.values()),
        key=lambda x: x["window"].values[0],
    )

    def _state(frame):
        # Bias parameters of a window, read from the first row of its frame.
        return {
            "lam1": frame["lambda1"].values[0],
            "lam2": frame["lambda2"].values[0],
            "alpha": frame["alpha"].values[0],
            "u0": frame["uh"].values[0],
            "w0": frame["w0"].values[0],
        }

    n_samples = [len(df) for df in dataframes]
    pert_es = [df["pert_en"].values for df in dataframes]

    # Remove each window's own bias from its potential energies.
    # Should only matter in cases where states are at different temps,
    # leaving here for debugging and parity with GL code
    pots = [
        df["pot_en"].values - _bias_fcn(pert_e, **_state(df))
        for df, pert_e in zip(dataframes, pert_es)
    ]

    # The two legs are split at the inflection indices.
    last_forward = inflection_indices[0]
    first_reverse = inflection_indices[1]
    beta0 = dataframes[0]["beta"].values[0]

    # Forward leg: windows 0 .. last_forward (inclusive).
    n_samples_first_half = n_samples[: last_forward + 1]
    pots_first_half = _numpy.concatenate(pots[: last_forward + 1])
    pert_es_first_half = _numpy.concatenate(pert_es[: last_forward + 1])
    ln_q = _numpy.zeros((last_forward + 1, len(pots_first_half)))

    for be in range(len(n_samples_first_half)):
        ln_q[be] = _npot_fcn(
            e0=pots_first_half,
            epert=pert_es_first_half,
            bet=dataframes[be]["beta"].values[0],
            **_state(dataframes[be]),
        )
    f_i, d_i, _ = _estimate_f_i(ln_q, n_samples_first_half)
    ddg1 = (f_i[-1] - f_i[0]) / beta0
    ddg_error_1 = _numpy.sqrt(d_i[-1] + d_i[0]) / beta0

    # Reverse leg: windows first_reverse .. end.
    # NOTE(review): the number of rows uses 'total_states' (the number of
    # folders read) whereas 'dataframes' has one entry per distinct
    # 'window' value; the shapes only agree when the two counts match.
    n_samples_second_half = n_samples[first_reverse:]
    pots_second_half = _numpy.concatenate(pots[first_reverse:])
    pert_es_second_half = _numpy.concatenate(pert_es[first_reverse:])
    ln_q = _numpy.zeros((total_states - first_reverse, len(pots_second_half)))

    # NOTE(review): the state parameters are read from dataframes[be] with
    # 'be' starting at 0, i.e. from the first windows rather than from the
    # reverse-leg windows themselves. This is kept exactly as it was;
    # presumably it relies on a symmetric lambda schedule - confirm.
    for be in range(len(n_samples_second_half)):
        ln_q[be] = _npot_fcn(
            e0=pots_second_half,
            epert=pert_es_second_half,
            bet=dataframes[be]["beta"].values[0],
            **_state(dataframes[be]),
        )
    f_i, d_i, _ = _estimate_f_i(ln_q, n_samples_second_half)
    ddg2 = (f_i[-1] - f_i[0]) / beta0
    ddg_error_2 = _numpy.sqrt(d_i[-1] + d_i[0]) / beta0

    # Combine the legs; the errors are added in quadrature.
    ddg_total = ddg1 - ddg2
    ddg_total_error = _numpy.sqrt(ddg_error_1**2 + ddg_error_2**2)
    from BioSimSpace.Units import Energy as _Energy

    ddg_total = ddg_total * _Energy.kcal_per_mol
    ddg_total_error = ddg_total_error * _Energy.kcal_per_mol

    return ddg_total, ddg_total_error
def analyse_MBAR(work_dir):
    """
    Analyse the MBAR-compatible outputs.
    Adapted version of BioSimSpace _analyse_internal function

    Parameters
    ----------

    work_dir : str
        The directory that is searched recursively for 'energies*.csv' files.

    Returns
    -------

    data_forward_final : [(float, Energy, Energy)]
        For the forward leg, one (lambda, free energy, error) tuple per
        re-numbered lambda value, in kcal/mol.

    data_backward_final : [(float, Energy, Energy)]
        The same for the reverse leg.

    Raises
    ------

    TypeError
        If 'work_dir' is not a string.

    ValueError
        If 'work_dir' is not a directory, no energy files are found, a file
        is inconsistent, the temperatures differ, or the MBAR fit fails.
    """
    from ._relative import Relative as _Relative
    from alchemlyb.postprocessors.units import to_kcalmol as _to_kcalmol
    from .. import Units as _Units

    try:
        from alchemlyb.estimators import AutoMBAR as _AutoMBAR
    except ImportError:
        from alchemlyb.estimators import MBAR as _AutoMBAR

    if not isinstance(work_dir, str):
        raise TypeError("work_dir must be a string")
    if not _os.path.isdir(work_dir):
        raise ValueError("work_dir must be a valid directory")

    files = sorted(_pathlib.Path(work_dir).glob("**/energies*.csv"))
    if not files:
        # Fail here with the real cause instead of further down with a
        # misleading complaint about temperatures.
        raise ValueError(f"No 'energies*.csv' files found in '{work_dir}'")

    # Slightly more complicated than a standard FE calculation
    # the key complication comes from the need to split the forward and reverse legs
    # instead of being inherently separate as in a standard FE calculation, they
    # are dictated by 'direction'. This means that the energy arrays need
    # to be re-numbered in to separate forward and reverse legs.

    # need to make sure that all lambdas were run at the same temp
    temps = []
    dataframes_forward = []
    dataframes_backward = []
    for file in files:
        df = _pd.read_csv(file)

        # The temperature must be constant within a file.
        temp = df["temperature"].unique()
        if len(temp) != 1:
            raise ValueError(f"Temperature column in {file} is not uniform")
        temps.append(temp[0])

        # A file whose last column is entirely NaN belongs to the forward
        # leg, anything else to the reverse leg.
        reverse = not df.iloc[:, -1:].isnull().values.all()

        df = df.drop(columns=["temperature"])
        # remove columns with NaN values
        df = df.dropna(axis=1)

        # fep-lambda should be the same value for all entries of a file
        fep_lambda = df["fep-lambda"].unique()
        if len(fep_lambda) != 1:
            raise ValueError(f"fep-lambda column in {file} is not uniform")

        # Split the columns into those whose title is a number (one per
        # lambda value) and the rest, remembering which numeric column
        # corresponds to this file's own fep-lambda. The position is reset
        # for every file so a value from a previous file is never reused.
        cols = []
        num_lams = 0
        index_fep_lambda = None
        for col in df.columns:
            try:
                val = float(col)
            except ValueError:
                cols.append(col)
                continue
            if val == fep_lambda[0]:
                index_fep_lambda = num_lams
            num_lams += 1
        if index_fep_lambda is None:
            raise ValueError(
                f"fep-lambda value {fep_lambda[0]} has no matching energy column in {file}"
            )

        # Re-number the lambda columns evenly between 0 and 1.
        new_lambdas = list(_numpy.linspace(0, 1, num_lams))
        new_fep_lambda = new_lambdas[index_fep_lambda]
        # NOTE(review): assumes the non-numeric columns precede the numeric
        # ones in the file - confirm against the writer of these files.
        df.columns = cols + new_lambdas
        df["fep-lambda"] = new_fep_lambda
        df.set_index(cols, inplace=True)

        if reverse:
            dataframes_backward.append(df)
        else:
            dataframes_forward.append(df)

    # check that all temperatures are the same
    if len(set(temps)) != 1:
        raise ValueError("All temperatures must be the same")

    def _attrs():
        # Metadata attached to every frame handed to alchemlyb.
        return {"temperature": temps[0], "energy_unit": "kJ/mol"}

    data_forward = _Relative._preprocess_data(dataframes_forward, "MBAR")
    data_backward = _Relative._preprocess_data(dataframes_backward, "MBAR")
    data_forward.attrs = _attrs()
    data_backward.attrs = _attrs()

    try:
        alchem_forward = _AutoMBAR().fit(data_forward)
    except ValueError as e:
        raise ValueError(
            f"Error in fitting forward leg of MBAR calculation: {e}"
        ) from e

    try:
        alchem_backward = _AutoMBAR().fit(data_backward)
    except ValueError as e:
        raise ValueError(
            f"Error in fitting backward leg of MBAR calculation: {e}"
        ) from e

    def _leg_results(estimator):
        # Convert the first row of the fitted matrices to kcal/mol tuples.
        # NOTE: 'new_lambdas' is the grid of the last file that was read;
        # both legs are assumed to share the same number of lambda values.
        estimator.delta_f_.attrs = _attrs()
        delta_f = _to_kcalmol(estimator.delta_f_)
        estimator.d_delta_f_.attrs = _attrs()
        d_delta_f = _to_kcalmol(estimator.d_delta_f_)
        return [
            (
                lamb,
                (delta_f.iloc[0, x]) * _Units.Energy.kcal_per_mol,
                (d_delta_f.iloc[0, x]) * _Units.Energy.kcal_per_mol,
            )
            for x, lamb in enumerate(new_lambdas)
        ]

    data_forward_final = _leg_results(alchem_forward)
    data_backward_final = _leg_results(alchem_backward)

    return data_forward_final, data_backward_final
+###################################################################### +# BioSimSpace: Making biomolecular simulation a breeze! +# +# Copyright: 2017-2024 +# +# Authors: Lester Hedges +# Matthew Burman +# +# BioSimSpace is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# BioSimSpace is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with BioSimSpace. If not, see . +###################################################################### + +import math as _math +import warnings as _warnings + +from .._Exceptions import IncompatibleError as _IncompatibleError +from .. import Protocol as _Protocol +from ._atm_utils import _ATMUtils +from ._openmm import OpenMM as _OpenMM + + +class OpenMMATM(_OpenMM): + """ + Derived class for running ATM simulations using OpenMM. Overloads the + _generate_config() to introduce ATM-specific methods. + """ + + def __init__( + self, + system=None, + protocol=None, + reference_system=None, + exe=None, + name="openmm", + platform="CPU", + work_dir=None, + seed=None, + property_map={}, + **kwargs, + ): + # Look for the is_testing flag in the kwargs. + # Only used for calculating single point energies. 
def _check_space(self):
    """
    Return whether the system has a periodic simulation box.

    The "space" property is looked up through the user property map. A
    warning is issued and False is returned when the system has no such
    property. False is also returned when the property cannot be read or
    queried.
    """
    # Get the "space" property from the user mapping.
    prop = self._property_map.get("space", "space")

    # Check whether the system contains periodic box information.
    if prop not in self._system._sire_object.propertyKeys():
        _warnings.warn("No simulation box found. Assuming gas phase simulation.")
        return False

    try:
        # Make sure that we have a periodic box. The system will now have
        # a default cartesian space.
        box = self._system._sire_object.property(prop)
        return box.isPeriodic()
    except Exception:
        # Any ordinary failure means "no usable box". Interpreter-level
        # signals such as KeyboardInterrupt are deliberately not caught.
        return False
def _add_pressure_check(self, pressure, temperature, is_periodic):
    """
    Add a Monte Carlo barostat to the configuration when one is applicable.

    Nothing is written when 'pressure' is None. For a non-periodic system
    a warning is issued instead of adding a barostat. Returns True only
    when the barostat lines were added to the configuration.
    """
    # No pressure means no barostat.
    if pressure is None:
        return False

    # Cannot use a barostat with a non-periodic system.
    if not is_periodic:
        _warnings.warn(
            "Cannot use a barostat for a vacuum or non-periodic simulation"
        )
        return False

    # Convert to bar and get the value.
    pressure_in_bar = pressure.bar().value()

    # Create the barostat and add its force to the system.
    self.addToConfig("\n# Add a barostat to run at constant pressure.")
    self.addToConfig(
        f"barostat = MonteCarloBarostat({pressure_in_bar}*bar, {temperature}*kelvin)"
    )
    if self._is_seeded:
        self.addToConfig(f"barostat.setRandomNumberSeed({self._seed})")
    self.addToConfig("system.addForce(barostat)")

    return True
def _generate_config_minimisation(self):
    """
    Build the OpenMM script for an ATM minimisation protocol.

    The configuration list is reset and then filled, in order, with the
    imports and monkey patches, system initialisation, platform settings,
    optional position restraints, the ATM-specific displacement, alignment
    and COM restraint snippets, the simulation object, the minimisation
    call, the reporters and a final single integration step.
    """
    protocol = self._protocol
    util = _ATMUtils(protocol)

    # Clear the existing configuration list.
    self._config = []

    has_box = self._check_space()
    self._add_config_imports()
    self._add_config_monkey_patches()
    self._add_initialisation(has_box)

    # Add the platform information.
    self._add_config_platform()

    # Add any position restraints.
    if protocol.getRestraint() is not None:
        restraint = protocol.getRestraint()
        # A keyword is resolved to atom indices, whereas a user-defined
        # list of indices is used as given.
        restrained_atoms = (
            self._system.getRestraintAtoms(restraint)
            if isinstance(restraint, str)
            else restraint
        )
        self.addToConfig("\n# Add position restraints.")
        self.addToConfig(util.create_flat_bottom_restraint(restrained_atoms))

    # Add the atom-specific restraints.
    self.addToConfig(util.createDisplacement())

    if protocol.getCoreAlignment():
        alignment_snippet = util.createAlignmentForce()
        self.addToConfig("\n# Add alignment force.")
        self.addToConfig(alignment_snippet)

    if protocol.getCOMDistanceRestraint():
        com_snippet = util.createCOMRestraint()
        self.addToConfig("\n# Add COM restraint.")
        self.addToConfig(com_snippet)

    self._add_simulation_instantiation()

    max_iterations = protocol.getSteps()
    self.addToConfig(f"simulation.minimizeEnergy(maxIterations={max_iterations})")

    # Add the reporters.
    self.addToConfig("\n# Add reporters.")
    self._add_config_reporters(state_interval=1, traj_interval=1)

    # Now run the simulation.
    self.addToConfig(
        "\n# Run a single simulation step to allow us to get the system and energy."
    )
    self.addToConfig("simulation.step(1)")

    # Flag that this isn't a custom protocol.
    protocol._setCustomised(False)
+ disp = util.createDisplacement() + self.addToConfig(disp) + if self._protocol.getUseATMForce(): + atm = util.createATMForce(index=None) + self.addToConfig(atm) + + if self._protocol.getCoreAlignment(): + alignment = util.createAlignmentForce() + self.addToConfig("\n# Add alignment force.") + self.addToConfig(alignment) + if self._protocol.getCOMDistanceRestraint(): + CMCM = util.createCOMRestraint() + self.addToConfig("\n# Add COM restraint.") + self.addToConfig(CMCM) + + # Get the integration time step from the protocol. + timestep = self._protocol.getTimeStep().picoseconds().value() + + # Set the integrator. + self.addToConfig("\n# Define the integrator.") + self.addToConfig(f"integrator = LangevinMiddleIntegrator({temperature}*kelvin,") + friction = 1 / self._protocol.getThermostatTimeConstant().picoseconds().value() + self.addToConfig(f" {friction:.5f}/picosecond,") + self.addToConfig(f" {timestep}*picoseconds)") + if self._is_seeded: + self.addToConfig(f"integrator.setRandomNumberSeed({self._seed})") + + # Add the platform information. + self._add_config_platform() + + self._add_simulation_instantiation() + + # Set initial velocities from temperature distribution. + self.addToConfig("\n# Setting initial system velocities.") + self.addToConfig( + f"simulation.context.setVelocitiesToTemperature({temperature})" + ) + + # Work out the number of integration steps. + steps = _math.ceil(self._protocol.getRunTime() / self._protocol.getTimeStep()) + + # Get the report and restart intervals. + report_interval = self._protocol.getReportInterval() + restart_interval = self._protocol.getRestartInterval() + + # Cap the intervals at the total number of steps. + if report_interval > steps: + report_interval = steps + if restart_interval > steps: + restart_interval = steps + + # Add the reporters. 
def _generate_config_annealing(self):
    """
    Assemble the OpenMM script for an ATM annealing run.

    Lines are appended through ``self.addToConfig`` in this order: imports,
    system initialisation, optional flat-bottom position restraints, the
    displacement / ATM / alignment / COM forces written by ``_ATMUtils``,
    the Langevin integrator, restart handling, reporters and finally the
    annealing protocol.

    If the restart information reports that no integration steps remain,
    a message is printed and the method returns with the script only
    partially written.
    """
    # NOTE(review): this method uses the underscore-prefixed accessors
    # while _generate_config_production calls set_current_index() and
    # get_window_index(); confirm which spelling the protocol provides.
    self._protocol._set_current_index(0)
    util = _ATMUtils(self._protocol)
    emit = self.addToConfig

    # Clear the existing configuration list.
    self._config = []

    has_box = self._check_space()

    # Add standard openMM config
    for line in (
        "from glob import glob",
        "import math",
        "import os",
        "import shutil",
    ):
        emit(line)
    self._add_config_imports()
    self._add_config_monkey_patches()

    is_periodic = self._add_initialisation(has_box)

    # Get the starting temperature and system pressure.
    temperature = self._protocol.getTemperature().kelvin().value()
    pressure = self._protocol.getPressure()

    # The returned constant-pressure flag is not needed by this method.
    self._add_pressure_check(pressure, temperature, is_periodic)

    # Add any position restraints.
    restraint = self._protocol.getRestraint()
    if restraint is not None:
        if isinstance(restraint, str):
            # Search for the atoms to restrain by keyword.
            restrained_atoms = self._system.getRestraintAtoms(restraint)
        else:
            # Use the user-defined list of indices.
            restrained_atoms = restraint
        emit("\n# Add position restraints.")
        emit(util.create_flat_bottom_restraint(restrained_atoms))

    # Use utils to create ATM-specific forces.
    # The ATM force is the only window-dependent force.
    emit(util.createDisplacement())
    emit("\n# Add ATM Force.")
    emit(util.createATMForce(self._protocol._get_window_index()))
    if self._protocol.getCoreAlignment():
        alignment = util.createAlignmentForce()
        emit("\n# Add alignment force.")
        emit(alignment)

    if self._protocol.getCOMDistanceRestraint():
        com_restraint = util.createCOMRestraint()
        emit("\n# Add COM restraint.")
        emit(com_restraint)

    # Get the integration time step from the protocol.
    timestep = self._protocol.getTimeStep().picoseconds().value()

    # Set the integrator.
    emit("\n# Define the integrator.")
    emit(f"integrator = LangevinMiddleIntegrator({temperature}*kelvin,")
    friction = 1 / self._protocol.getThermostatTimeConstant().picoseconds().value()
    emit(f" {friction:.5f}/picosecond,")
    emit(f" {timestep}*picoseconds)")
    if self._is_seeded:
        emit(f"integrator.setRandomNumberSeed({self._seed})")

    # Add the platform information.
    self._add_config_platform()

    self._add_simulation_instantiation()

    # Set initial velocities from temperature distribution.
    emit("\n# Setting initial system velocities.")
    emit(f"simulation.context.setVelocitiesToTemperature({temperature})")

    # Check for a restart file and load the simulation state.
    is_restart, step = self._add_config_restart()

    # Work out the number of integration steps still to run.
    total_steps = _math.ceil(
        self._protocol.getRunTime() / self._protocol.getTimeStep()
    )
    steps = total_steps - step

    # Exit if the simulation has already finished.
    if steps <= 0:
        print("The simulation has already finished!")
        return

    # Inform user that a restart was loaded.
    emit("\n# Print restart information.")
    emit("if is_restart:")
    emit(f" steps = {total_steps}")
    emit(" percent_complete = 100 * (step / steps)")
    emit(" print('Loaded state from an existing simulation.')")
    emit(" print(f'Simulation is {percent_complete}% complete.')")

    # Report/restart intervals, capped at the number of remaining steps.
    report_interval = self._protocol.getReportInterval()
    restart_interval = self._protocol.getRestartInterval()
    if report_interval > steps:
        report_interval = steps
    if restart_interval > steps:
        restart_interval = steps

    # Add the reporters.
    emit("\n# Add reporters.")
    self._add_config_reporters(
        state_interval=report_interval,
        traj_interval=restart_interval,
        is_restart=is_restart,
    )

    # The annealing loop itself is written by the ATM utilities.
    emit(util.createAnnealingProtocol())
def _generate_config_production(self):
    """
    Assemble the OpenMM script for an ATM production run.

    The script is built through ``self.addToConfig``: imports, the optional
    soft-core helper (for the "UWHAM" and "both" analysis methods), system
    initialisation, optional position restraints, the ATM-specific forces,
    the integrator, restart logic, reporters and the sampling loop chosen
    by the protocol's analysis method.

    Raises
    ------

    _IncompatibleError
        If the protocol reports no current window index.
    """
    self._protocol.set_current_index(0)
    analysis_method = self._protocol.getAnalysisMethod()
    util = _ATMUtils(self._protocol)
    emit = self.addToConfig

    # Clear the existing configuration list.
    self._config = []

    has_box = self._check_space()

    # TODO: check extra_options, extra_lines and property_map
    if self._protocol.get_window_index() is None:
        raise _IncompatibleError(
            "ATM protocol requires the current window index to be set."
        )

    # Write the OpenMM import statements.
    for line in (
        "import pandas as pd",
        "import numpy as np",
        "from glob import glob",
        "import math",
        "import os",
        "import shutil",
    ):
        emit(line)
    self._add_config_imports()
    self._add_config_monkey_patches()
    emit("\n")
    if analysis_method in ("UWHAM", "both"):
        emit(util.createSoftcorePertE())

    # Add standard openMM config
    is_periodic = self._add_initialisation(has_box)

    # Get the starting temperature and system pressure.
    temperature = self._protocol.getTemperature().kelvin().value()
    pressure = self._protocol.getPressure()

    # The returned constant-pressure flag is not needed by this method.
    self._add_pressure_check(pressure, temperature, is_periodic)

    # Add any position restraints.
    restraint = self._protocol.getRestraint()
    if restraint is not None:
        if isinstance(restraint, str):
            # Search for the atoms to restrain by keyword.
            restrained_atoms = self._system.getRestraintAtoms(restraint)
        else:
            # Use the user-defined list of indices.
            restrained_atoms = restraint
        emit("\n# Add position restraints.")
        emit(util.create_flat_bottom_restraint(restrained_atoms))

    # Use utils to create ATM-specific forces.
    # The ATM force is the only window-dependent force.
    emit(util.createDisplacement())
    emit("\n# Add ATM Force.")
    emit(util.createATMForce(self._protocol.get_window_index()))
    if self._protocol.getCoreAlignment():
        alignment = util.createAlignmentForce()
        emit("\n# Add alignment force.")
        emit(alignment)

    if self._protocol.getCOMDistanceRestraint():
        com_restraint = util.createCOMRestraint()
        emit("\n# Add COM restraint.")
        emit(com_restraint)

    # Get the integration time step from the protocol.
    timestep = self._protocol.getTimeStep().picoseconds().value()

    # Set the integrator.
    emit("\n# Define the integrator.")
    emit(f"integrator = LangevinMiddleIntegrator({temperature}*kelvin,")
    friction = 1 / self._protocol.getThermostatTimeConstant().picoseconds().value()
    emit(f" {friction:.5f}/picosecond,")
    emit(f" {timestep}*picoseconds)")
    if self._is_seeded:
        emit(f"integrator.setRandomNumberSeed({self._seed})")

    # Add the platform information.
    self._add_config_platform()

    self._add_simulation_instantiation()

    # Set initial velocities from temperature distribution.
    emit("\n# Setting initial system velocities.")
    emit(f"simulation.context.setVelocitiesToTemperature({temperature})")

    # Check for a restart file and load the simulation state.
    is_restart, _ = self._add_config_restart()

    # NOTE: The restarting logic here is different to previous openMM classes.
    # The step count returned by the restart function is deliberately ignored:
    # the number of completed steps is worked out at runtime inside the
    # generated script, so restarting through the BioSimSpace runner and
    # re-running the script directly behave the same.
    steps = _math.ceil(
        self._protocol.getRunTime() / self._protocol.getTimeStep()
    )
    total_steps = steps

    # Exit if there is nothing to run.
    if steps <= 0:
        print("The simulation has already finished!")
        return

    # Report/restart intervals, capped at the total number of steps.
    report_interval = self._protocol.getReportInterval()
    restart_interval = self._protocol.getRestartInterval()
    if report_interval > steps:
        report_interval = steps
    if restart_interval > steps:
        restart_interval = steps

    # Work out the total simulation time in picoseconds.
    run_time = steps * timestep

    # One cycle spans one report interval (report_interval * timestep ps).
    cycles = _math.ceil(run_time / (report_interval * timestep))

    # Work out the number of steps per cycle.
    steps_per_cycle = int(steps / cycles)

    emit(
        util.createRestartLogic(
            total_cycles=cycles, steps_per_cycle=steps_per_cycle
        )
    )

    # Inform user that a restart was loaded.
    emit("\n# Print restart information.")
    emit("if is_restart:")
    emit(f" steps = {total_steps}")
    emit(" percent_complete = 100 * (step / steps)")
    emit(" print('Loaded state from an existing simulation.')")
    emit(" print(f'Simulation is {percent_complete}% complete.')")
    emit(" print(f'running an additional {numcycles} cycles')")

    # Add the reporters.
    emit("\n# Add reporters.")
    self._add_config_reporters(
        state_interval=report_interval,
        traj_interval=restart_interval,
        is_restart=is_restart,
    )

    emit(f"\ntemperature = {temperature}")

    # Now run the simulation, using the loop that matches the analysis method.
    if analysis_method == "UWHAM":
        emit(
            util.createSoftcorePertELoop(
                name=self._name,
                steps_per_cycle=steps_per_cycle,
                report_interval=report_interval,
                timestep=timestep,
            )
        )
        return

    # Index of the first window whose direction differs from its
    # predecessor; stays 0 when the direction never changes.
    direction = self._protocol.getDirection()
    inflex = next(
        (
            i + 1
            for i in range(len(direction) - 1)
            if direction[i] != direction[i + 1]
        ),
        0,
    )

    if analysis_method == "both":
        emit(
            util.createReportingBoth(
                name=self._name,
                steps_per_cycle=steps_per_cycle,
                timestep=timestep,
                inflex_point=inflex,
            )
        )
    else:
        emit(
            util.createLoopWithReporting(
                name=self._name,
                steps_per_cycle=steps_per_cycle,
                report_interval=report_interval,
                timestep=timestep,
                inflex_point=inflex,
            )
        )
def _generate_config_single_point_testing(self):
    """
    Assemble an OpenMM script that evaluates single-point energies.

    Designed as a hidden method: it reuses a production protocol to
    calculate single point energies for each lambda window. Every
    ATM-related force is placed in its own force group (position restraint
    5, alignment 6/7/8, COM restraint 9, ATM force 10) and the same groups
    are handed to ``createSinglePointTest``.

    Raises
    ------

    _IncompatibleError
        If the protocol is not an ``ATMProduction`` protocol, or if it
        reports no current window index.
    """
    # Validate the protocol type before touching any production-only API.
    if not isinstance(self._protocol, _Protocol.ATMProduction):
        raise _IncompatibleError(
            "Single point testing requires an ATMProduction protocol."
        )
    self._protocol.set_current_index(0)
    util = _ATMUtils(self._protocol)
    emit = self.addToConfig

    # Clear the existing configuration list.
    self._config = []

    has_box = self._check_space()

    # TODO: check extra_options, extra_lines and property_map
    if self._protocol.get_window_index() is None:
        raise _IncompatibleError(
            "ATM protocol requires the current window index to be set."
        )

    # Write the OpenMM import statements.
    for line in (
        "import pandas as pd",
        "import numpy as np",
        "from glob import glob",
        "import math",
        "import os",
        "import shutil",
    ):
        emit(line)
    self._add_config_imports()
    self._add_config_monkey_patches()
    emit("\n")

    # Add standard openMM config
    is_periodic = self._add_initialisation(has_box)

    # Get the starting temperature and system pressure.
    temperature = self._protocol.getTemperature().kelvin().value()
    pressure = self._protocol.getPressure()

    # The returned constant-pressure flag is not needed by this method.
    self._add_pressure_check(pressure, temperature, is_periodic)

    # Add any position restraints.
    restraint = self._protocol.getRestraint()
    if restraint is not None:
        if isinstance(restraint, str):
            # Search for the atoms to restrain by keyword.
            restrained_atoms = self._system.getRestraintAtoms(restraint)
        else:
            # Use the user-defined list of indices.
            restrained_atoms = restraint
        emit("\n# Add position restraints.")
        emit(util.create_flat_bottom_restraint(restrained_atoms, force_group=5))

    # Use utils to create ATM-specific forces.
    # The ATM force is the only window-dependent force.
    emit(util.createDisplacement())
    emit("\n# Add ATM Force.")
    emit(util.createATMForce(self._protocol.get_window_index(), force_group=10))
    if self._protocol.getCoreAlignment():
        alignment = util.createAlignmentForce(force_group=[6, 7, 8])
        emit("\n# Add alignment force.")
        emit(alignment)

    if self._protocol.getCOMDistanceRestraint():
        com_restraint = util.createCOMRestraint(force_group=9)
        emit("\n# Add COM restraint.")
        emit(com_restraint)

    # Get the integration time step from the protocol.
    timestep = self._protocol.getTimeStep().picoseconds().value()

    # Set the integrator.
    emit("\n# Define the integrator.")
    emit(f"integrator = LangevinMiddleIntegrator({temperature}*kelvin,")
    friction = 1 / self._protocol.getThermostatTimeConstant().picoseconds().value()
    emit(f" {friction:.5f}/picosecond,")
    emit(f" {timestep}*picoseconds)")
    if self._is_seeded:
        emit(f"integrator.setRandomNumberSeed({self._seed})")

    # Add the platform information.
    self._add_config_platform()

    self._add_simulation_instantiation()

    # Set initial velocities from temperature distribution.
    emit("\n# Setting initial system velocities.")
    emit(f"simulation.context.setVelocitiesToTemperature({temperature})")

    # Check for a restart file and load the simulation state.
    is_restart, step = self._add_config_restart()

    # Work out the number of integration steps still to run.
    total_steps = _math.ceil(
        self._protocol.getRunTime() / self._protocol.getTimeStep()
    )
    steps = total_steps - step

    # Exit if the simulation has already finished.
    if steps <= 0:
        print("The simulation has already finished!")
        return

    # Inform user that a restart was loaded.
    emit("\n# Print restart information.")
    emit("if is_restart:")
    emit(f" steps = {total_steps}")
    emit(" percent_complete = 100 * (step / steps)")
    emit(" print('Loaded state from an existing simulation.')")
    emit(" print(f'Simulation is {percent_complete}% complete.')")

    # Report/restart intervals, capped at the number of remaining steps.
    report_interval = self._protocol.getReportInterval()
    restart_interval = self._protocol.getRestartInterval()
    if report_interval > steps:
        report_interval = steps
    if restart_interval > steps:
        restart_interval = steps

    # Add the reporters.
    emit("\n# Add reporters.")
    self._add_config_reporters(
        state_interval=report_interval,
        traj_interval=restart_interval,
        is_restart=is_restart,
    )

    emit(f"\ntemperature = {temperature}")

    # Index of the first window whose direction differs from its
    # predecessor; stays 0 when the direction never changes.
    direction = self._protocol.getDirection()
    inflex = next(
        (
            i + 1
            for i in range(len(direction) - 1)
            if direction[i] != direction[i + 1]
        ),
        0,
    )
    emit(
        util.createSinglePointTest(
            inflex,
            self._name,
            atm_force_group=10,
            position_restraint_force_group=5,
            alignment_force_groups=[6, 7, 8],
            com_force_group=9,
        )
    )
+ self.addToConfig("\n# Add reporters.") + self._add_config_reporters( + state_interval=report_interval, + traj_interval=restart_interval, + is_restart=is_restart, + ) + + # Work out the total simulation time in picoseconds. + run_time = steps * timestep + + # Work out the number of cycles in 100 picosecond intervals. + cycles = _math.ceil(run_time / (report_interval * timestep)) + + # Work out the number of steps per cycle. + steps_per_cycle = int(steps / cycles) + + self.addToConfig(f"\ntemperature = {temperature}") + # reading in the directions from the protocol, find the index at which direction changes + direction = self._protocol.getDirection() + inflex = 0 + for i in range(len(direction) - 1): + if direction[i] != direction[i + 1]: + inflex = i + 1 + break + self.addToConfig( + util.createSinglePointTest( + inflex, + self._name, + atm_force_group=10, + position_restraint_force_group=5, + alignment_force_groups=[6, 7, 8], + com_force_group=9, + ) + ) diff --git a/python/BioSimSpace/Process/_atm_utils.py b/python/BioSimSpace/Process/_atm_utils.py new file mode 100644 index 000000000..12f769c6f --- /dev/null +++ b/python/BioSimSpace/Process/_atm_utils.py @@ -0,0 +1,991 @@ +###################################################################### +# BioSimSpace: Making biomolecular simulation a breeze! +# +# Copyright: 2017-2023 +# +# Authors: Lester Hedges +# +# BioSimSpace is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# BioSimSpace is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with BioSimSpace. If not, see . 
#####################################################################

__all__ = ["_ATMUtils"]

import math as _math
import warnings as _warnings

from .. import Protocol as _Protocol
from ..Types import Vector as _Vector
from ..Protocol._atm import _ATM


class _ATMUtils:
    """Internal class for creating openmm forces within an ATM process."""

    def __init__(self, protocol):
        """
        Store the protocol and the data dictionary it exposes.

        Raises
        ------

        TypeError
            If ``protocol`` is not an ATM protocol.
        """
        # Check for proper typing
        if not isinstance(protocol, _ATM):
            raise TypeError("Protocol must be an ATM protocol")
        self.protocol = protocol
        self.data = self.protocol.getData()

    def getAlignmentConstants(self):
        """Cache the three alignment force constants from the protocol."""
        self.alignment_k_distance = self.protocol.getAlignKDistance().value()
        self.alignment_k_theta = self.protocol.getAlignKTheta().value()
        self.alignment_k_psi = self.protocol.getAlignKPsi().value()

    def getCMConstants(self):
        """Cache the COM restraint force constant and width from the protocol."""
        self.cm_kf = self.protocol.getCOMk().value()
        self.cm_tol = self.protocol.getCOMWidth().value()

    def findAbsoluteCoreIndices(self):
        """
        Shift the per-ligand rigid-core indices by each ligand's first atom
        index.

        The results are stored as lists of built-in ``int`` so that they
        format as plain integer literals when written into a script.
        """
        self.lig1_first_atomnum = self.data["first_ligand_bound_atom_index"]
        self.lig1_rigid_atoms = [
            int(self.lig1_first_atomnum + i)
            for i in self.data["ligand_bound_rigid_core"]
        ]
        self.lig2_first_atomnum = self.data["first_ligand_free_atom_index"]
        self.lig2_rigid_atoms = [
            int(self.lig2_first_atomnum + i)
            for i in self.data["ligand_free_rigid_core"]
        ]

    def findAbsoluteCOMAtoms(self):
        """
        Shift the protein and ligand COM atom indices by the first atom
        index of the molecule they belong to.

        The results are stored as lists of built-in ``int`` so that they
        format as plain integer literals when written into a script.
        """
        self.protein_first_atomnum = self.data["first_protein_atom_index"]
        self.protein_com_atoms = [
            int(self.protein_first_atomnum + i)
            for i in self.data["protein_com_atoms"]
        ]

        self.lig1_first_atomnum = self.data["first_ligand_bound_atom_index"]
        self.lig1_com_atoms = [
            int(self.lig1_first_atomnum + i)
            for i in self.data["ligand_bound_com_atoms"]
        ]

        self.lig2_first_atomnum = self.data["first_ligand_free_atom_index"]
        self.lig2_com_atoms = [
            int(self.lig2_first_atomnum + i)
            for i in self.data["ligand_free_com_atoms"]
        ]
def getATMForceConstants(self, index=None):
    """
    Cache every constant needed to write the ATM force.

    Parameters
    ----------

    index : int
        Window index. Required for an ATMProduction protocol, where the
        per-window values are selected with it; not used for
        ATMEquilibration / ATMAnnealing protocols.

    Raises
    ------

    ValueError
        If the protocol is an ATMProduction protocol and ``index`` is None.
    """
    protocol = self.protocol
    self.lig1_atoms = self.getLigandBoundAtomsAsList()
    self.lig2_atoms = self.getLigandFreeAtomsAsList()
    self.SCUmax = protocol.getSoftCoreUmax().value()
    self.SCU0 = protocol.getSoftCoreU0().value()
    self.SCa = protocol.getSoftCoreA()
    if isinstance(protocol, _Protocol.ATMProduction):
        if index is None:
            raise ValueError("Index must be set for ATMProduction protocol")
        self.lambda1 = protocol.getLambda1()[index]
        self.lambda2 = protocol.getLambda2()[index]
        self.alpha = protocol.getAlpha()[index].value()
        self.uh = protocol.getUh()[index].value()
        self.w0 = protocol.getW0()[index].value()
        self.direction = protocol.getDirection()[index]
        self.master_lambda = protocol._get_lambda_values()[index]
    elif isinstance(
        protocol, (_Protocol.ATMEquilibration, _Protocol.ATMAnnealing)
    ):
        self.lambda1 = protocol.getLambda1()
        self.lambda2 = protocol.getLambda2()
        self.alpha = protocol.getAlpha().value()
        self.uh = protocol.getUh().value()
        self.w0 = protocol.getW0().value()
        self.direction = protocol.getDirection()


def _dump_atm_constants_to_dict(self):
    """
    Internal function to write all ATM window-dependent constants to a
    dictionary (as script text) to be used in sampling for analysis.

    Returns
    -------

    output : str
        Script text defining ``atm_constants`` and then multiplying its
        'Alpha', 'Uh' and 'W0' entries by ``kilocalories_per_mole``.
    """
    protocol = self.protocol
    entries = (
        ("Lambda1", protocol.getLambda1()),
        ("Lambda2", protocol.getLambda2()),
        ("Alpha", [i.value() for i in protocol.getAlpha()]),
        ("Uh", [i.value() for i in protocol.getUh()]),
        ("W0", [i.value() for i in protocol.getW0()]),
        ("Direction", protocol.getDirection()),
    )

    output = "atm_constants = {\n"
    output += ",\n".join(
        "    '{}': {}".format(key, values) for key, values in entries
    )
    output += "\n}\n"

    # The emitted loop is nested three levels deep, so each level needs its
    # own indentation for the generated script to be valid Python.
    output += "for key in atm_constants.keys():\n"
    output += "    if key in ['Alpha','Uh','W0']:\n"
    output += "        atm_constants[key] = [i for i in atm_constants[key] * kilocalories_per_mole]\n"

    return output
def findDisplacement(self):
    """
    Read the ligand displacement from the protocol data and store it as
    ``self.displacement`` (a list of three components).

    Raises
    ------

    TypeError
        If the stored displacement is a list containing non-floats, or is
        neither a list nor a BioSimSpace vector.
    """
    d = self.data["displacement"]
    if isinstance(d, list):
        if not all(isinstance(x, float) for x in d):
            raise TypeError("Displacement must be a list of floats")
        self.displacement = d
    elif isinstance(d, _Vector):
        self.displacement = [d.x(), d.y(), d.z()]
    else:
        raise TypeError("Displacement must be a list or BioSimSpace vector")


def createDisplacement(self):
    """
    Return script text that defines ``displacement`` and divides it by 10
    (angstrom to nanometre) inside the generated script.
    """
    self.findDisplacement()
    # float() keeps float subclasses (e.g. NumPy scalars) from leaking a
    # non-literal repr into the script.
    d = [round(float(x), 3) for x in self.displacement]
    output = "displacement = {}\n".format(d)
    output += "#BioSimSpace output is in angstrom, divide by 10 to convert to the expected units of nm\n"
    output += "displacement = [i/10.0 for i in displacement]\n"
    return output


def createSoftcorePertE(self):
    """Create the softcorePertE function for the Gallicchio lab analysis"""
    # The emitted function body must be indented, and the `if` body one
    # level deeper, for the generated script to be valid Python.
    return "".join(
        (
            "def softCorePertE(u, umax, ub, a):\n",
            "    usc = u\n",
            "    if u > ub:\n",
            "        gu = (u-ub)/(a*(umax-ub))\n",
            "        zeta = 1. + 2.*gu*(gu + 1.)\n",
            "        zetap = np.power( zeta, a)\n",
            "        usc = (umax-ub)*(zetap - 1.)/(zetap + 1.) + ub\n",
            "    return usc\n",
        )
    )


def createAlignmentForce(self, force_group=None):
    """
    Create the alignment force that keeps the ligands co-planar.

    Three forces are written: a distance force, an angle force and a
    dihedral force, all acting on the ligands' rigid-core atoms.

    Parameters
    ----------

    force_group : None or list
        Group of the force to be added to the system. If none defined then
        no force group will be set (therefore it will default to 0).
        Only tested for single-point energies.
        If a list is given the groups will be assigned in the order
        [distance, angle, dihedral]

    Returns
    -------

    output : str
        Script text that builds the three forces and adds them to ``system``.

    Raises
    ------

    ValueError
        If ``force_group`` is given but does not hold three entries.
    """
    if force_group is not None and len(force_group) != 3:
        raise ValueError("Force group must be a list of three integers")

    # This force is the same in every lambda window
    self.getAlignmentConstants()
    self.findAbsoluteCoreIndices()

    # Plain ints guarantee the index lists format as integer literals.
    idxs_a = [int(i) for i in self.lig1_rigid_atoms]
    idxs_b = [int(i) for i in self.lig2_rigid_atoms]

    chunks = [
        "\n\n",
        "k_distance = {} * kilocalorie_per_mole / angstrom**2\n".format(
            self.alignment_k_distance
        ),
        "k_theta = {} * kilocalorie_per_mole\n".format(self.alignment_k_theta),
        "k_psi = {} * kilocalorie_per_mole\n".format(self.alignment_k_psi),
        "idxs_a = {}\n".format(idxs_a),
        "idxs_b = {}\n".format(idxs_b),
        "\n\n",
    ]

    # --- Distance force -------------------------------------------------
    chunks += [
        'distance_energy_fn = "0.5 * k * ((x1 - x2 - dx)^2 + (y1 - y2 - dy)^2 + (z1 - z2 - dz)^2);"\n',
        "distance_force = CustomCompoundBondForce(2, distance_energy_fn)\n",
        "distance_force.addPerBondParameter('k')\n",
        "distance_force.addPerBondParameter('dx')\n",
        "distance_force.addPerBondParameter('dy')\n",
        "distance_force.addPerBondParameter('dz')\n",
        "distance_parameters = [\n"
        "    k_distance.value_in_unit(kilojoules_per_mole / nanometer**2),\n"
        "    displacement[0]*nanometer,\n"
        "    displacement[1]*nanometer,\n"
        "    displacement[2]*nanometer,\n"
        "    ]\n",
        "distance_force.addBond((idxs_b[0], idxs_a[0]), distance_parameters)\n",
    ]
    if force_group is not None:
        chunks.append("distance_force.setForceGroup({})\n".format(force_group[0]))
    chunks += ["system.addForce(distance_force)\n", "\n\n"]

    # --- Angle force ----------------------------------------------------
    chunks += [
        "angle_energy_fn = (\n"
        '    "0.5 * k * (1 - cos_theta);"\n'
        '    ""\n'
        '    "cos_theta = (dx_1 * dx_2 + dy_1 * dy_2 + dz_1 * dz_2) / (norm_1 * norm_2);"\n'
        '    ""\n'
        '    "norm_1 = sqrt(dx_1^2 + dy_1^2 + dz_1^2);"\n'
        '    "dx_1 = x2 - x1; dy_1 = y2 - y1; dz_1 = z2 - z1;"\n'
        '    ""\n'
        '    "norm_2 = sqrt(dx_2^2 + dy_2^2 + dz_2^2);"\n'
        '    "dx_2 = x4 - x3; dy_2 = y4 - y3; dz_2 = z4 - z3;"\n'
        "    )\n",
        "angle_force = CustomCompoundBondForce(4, angle_energy_fn)\n",
        'angle_force.addPerBondParameter("k")\n',
        "angle_force.addBond(\n"
        "    (idxs_b[0], idxs_b[1], idxs_a[0], idxs_a[1]),\n"
        "    [k_theta.value_in_unit(kilojoules_per_mole)],\n"
        "    )\n",
    ]
    if force_group is not None:
        chunks.append("angle_force.setForceGroup({})\n".format(force_group[1]))
    chunks.append("system.addForce(angle_force)\n\n")

    # --- Dihedral force (Gallicchio lab form) ---------------------------
    # The first fragment is assigned, every following one is appended with
    # `+=` in the generated script.
    dihedral_terms = (
        "(k/2) * (1 - cosp) ; ",
        "cosp = xvn*xwn + yvn*ywn + zvn*zwn ; ",
        "xvn = xv/v ; yvn = yv/v; zvn = zv/v ;",
        "v = sqrt(xv^2 + yv^2 + zv^2 ) ;",
        "xv = xd0 - dot01*xdn1 ;",
        "yv = yd0 - dot01*ydn1 ;",
        "zv = zd0 - dot01*zdn1 ;",
        "dot01 = xd0*xdn1 + yd0*ydn1 + zd0*zdn1 ;",
        "xd0 = x3 - x1 ;",
        "yd0 = y3 - y1 ;",
        "zd0 = z3 - z1 ;",
        "xwn = xw/w ; ywn = yw/w; zwn = zw/w ;",
        "w = sqrt(xw^2 + yw^2 + zw^2) ;",
        "xw = xd3 - dot31*xdn1 ;",
        "yw = yd3 - dot31*ydn1 ;",
        "zw = zd3 - dot31*zdn1 ;",
        "dot31 = xd3*xdn1 + yd3*ydn1 + zd3*zdn1 ;",
        "xd3 = x5 - x4 ;",
        "yd3 = y5 - y4 ;",
        "zd3 = z5 - z4 ;",
        "xdn1 = xd1/dn1 ; ydn1 = yd1/dn1 ; zdn1 = zd1/dn1 ;",
        "dn1 = sqrt(xd1^2 + yd1^2 + zd1^2) ;",
        "xd1 = x2 - x1 ;",
        "yd1 = y2 - y1 ;",
        "zd1 = z2 - z1 ;",
    )
    for position, term in enumerate(dihedral_terms):
        operator = "=" if position == 0 else "+="
        chunks.append('dihedral_energy_fn {} "{}"\n'.format(operator, term))

    chunks += [
        "dihedral_force = CustomCompoundBondForce(5, dihedral_energy_fn)\n",
        'dihedral_force.addPerBondParameter("k")\n',
        "dihedral_force.addBond(\n"
        "    (idxs_b[0], idxs_b[1], idxs_b[2], idxs_a[0], idxs_a[2]),\n"
        "    [0.5 * k_psi.value_in_unit(kilojoules_per_mole)],\n"
        "    )\n",
        "dihedral_force.addBond(\n"
        "    (idxs_a[0], idxs_a[1], idxs_a[2], idxs_b[0], idxs_b[2]),\n"
        "    [0.5 * k_psi.value_in_unit(kilojoules_per_mole)],\n"
        "    )\n",
    ]
    if force_group is not None:
        chunks.append("dihedral_force.setForceGroup({})\n".format(force_group[2]))
    chunks.append("system.addForce(dihedral_force)\n\n")

    return "".join(chunks)
def getLigandBoundAtomsAsList(self):
    """
    Return the indices of the bound ligand's atoms.

    The range runs from ``first_ligand_bound_atom_index`` to
    ``last_ligand_bound_atom_index`` inclusive. Entries are built-in
    ``int`` objects so the list formats as plain integer literals.
    """
    first = int(self.data["first_ligand_bound_atom_index"])
    last = int(self.data["last_ligand_bound_atom_index"])
    return list(range(first, last + 1))


def getLigandFreeAtomsAsList(self):
    """
    Return the indices of the free ligand's atoms.

    The range runs from ``first_ligand_free_atom_index`` to
    ``last_ligand_free_atom_index`` inclusive. Entries are built-in
    ``int`` objects so the list formats as plain integer literals.
    """
    first = int(self.data["first_ligand_free_atom_index"])
    last = int(self.data["last_ligand_free_atom_index"])
    return list(range(first, last + 1))
def createATMForce(
    self,
    index,
    force_group=None,
):
    """
    Create a string which can be added directly to an openmm script to add
    an ATM force to the system.

    Parameters
    ----------

    index : int
        Index of current window - used to set window-dependent variables.

    force_group : int
        Group of the force to be added to the system. Should only be
        needed when testing single-point energies.

    Returns
    -------

    output : str
        Script text that defines the ATM parameters, builds the
        ``ATMForce``, moves the system's ``NonbondedForce`` into it and
        adds it to ``system``.
    """
    self.findDisplacement()
    self.getATMForceConstants(index)
    is_production = isinstance(self.protocol, _Protocol.ATMProduction)

    # Plain ints guarantee the atom lists format as integer literals.
    lig1_atoms = [int(i) for i in self.lig1_atoms]
    lig2_atoms = [int(i) for i in self.lig2_atoms]

    output = "#Parameters for ATM force in original units\n"
    output += "lig1_atoms = {}\n".format(lig1_atoms)
    output += "lig2_atoms = {}\n".format(lig2_atoms)
    if is_production:
        output += "window_index = {}\n".format(index)
    output += "lambda1 = {}\n".format(self.lambda1)
    output += "lambda2 = {}\n".format(self.lambda2)
    output += "alpha = {} * kilocalories_per_mole\n".format(self.alpha)
    output += "uh = {} * kilocalories_per_mole\n".format(self.uh)
    output += "w0 = {} * kilocalories_per_mole\n".format(self.w0)
    output += "direction = {}\n".format(self.direction)
    output += "sc_Umax = {} * kilocalories_per_mole\n".format(self.SCUmax)
    output += "sc_U0 = {} * kilocalories_per_mole\n".format(self.SCU0)
    output += "sc_a = {}\n".format(self.SCa)

    # Production scripts also carry the constants of every window.
    if is_production:
        output += self._dump_atm_constants_to_dict()

    output += "\n\n #Define ATM force\n"
    output += (
        "atm_force = ATMForce(\n"
        "    lambda1,\n"
        "    lambda2,\n"
        "    alpha.value_in_unit(kilojoules_per_mole),\n"
        "    uh.value_in_unit(kilojoules_per_mole),\n"
        "    w0.value_in_unit(kilojoules_per_mole),\n"
        "    sc_Umax.value_in_unit(kilojoules_per_mole),\n"
        "    sc_U0.value_in_unit(kilojoules_per_mole),\n"
        "    sc_a,\n"
        "    direction,\n"
        "    )"
    )

    output += "\n\n #Add ATM force to system\n"
    output += "for _ in prm.topology.atoms():\n"
    output += " atm_force.addParticle(Vec3(0.0,0.0,0.0))"
    output += "\n"
    # TODO: add offset - check conversion of a list to a Vec3
    # Assuming that offset is the 3-vector which defines the ligand displacement
    # need to convert displacement to nm
    output += "for i in lig1_atoms:\n"
    output += " atm_force.setParticleParameters(i, Vec3(*displacement))\n"
    output += "for i in lig2_atoms:\n"
    output += " atm_force.setParticleParameters(i, -Vec3(*displacement))\n"
    output += "\n"
    # NOTE(review): the emitted lines below use `copy.deepcopy`; confirm the
    # generated script imports `copy` before this snippet runs.
    output += "nonbonded_force_id = [i for i, force in enumerate(system.getForces()) if isinstance(force, NonbondedForce)][0]\n"
    output += "nonbonded = copy.deepcopy(system.getForce(nonbonded_force_id))\n"
    output += "system.removeForce(nonbonded_force_id)\n"
    output += "atm_force.addForce(nonbonded)\n"
    if force_group is not None:
        output += "atm_force.setForceGroup({})\n".format(force_group)
    output += "system.addForce(atm_force)\n"
    return output
def createCOMRestraint(self, force_group=None):
    """
    Create a string containing the CM-CM restraints for two groups of atoms.
    In most cases these will be some combination of protein and ligand atoms.
    Constants for the force are set in the protocol.

    Two bonds are written: protein to bound ligand with zero offset, and
    protein to free ligand offset by ``displacement``.

    Parameters
    ----------

    force_group : None or int
        Group of the force to be added to the system. If None defined then
        no force group will be set (therefore it will default to 0).
        Only tested for single-point energies.

    Returns
    -------

    output : str
        Script text that builds the ``CustomCentroidBondForce`` and adds it
        to ``system``.

    Raises
    ------

    TypeError
        If ``force_group`` is given but is not an integer.
    """
    # Fail before doing any work if the force group is unusable.
    if force_group is not None and not isinstance(force_group, int):
        raise TypeError("Force group must be an integer")

    self.findAbsoluteCOMAtoms()
    self.getCMConstants()

    # Groups contained within the restraint. Plain ints guarantee the
    # lists format as integer literals.
    protein_com = [int(i) for i in self.protein_com_atoms]
    lig1_com = [int(i) for i in self.lig1_com_atoms]
    lig2_com = [int(i) for i in self.lig2_com_atoms]

    # Constants for the force
    kf_cm = self.cm_kf
    tol_cm = self.cm_tol

    chunks = [
        "protein_com = {}\n".format(protein_com),
        "lig1_com = {}\n".format(lig1_com),
        "lig2_com = {}\n".format(lig2_com),
        "# Constants for the CM-CM force in their input units\n",
        "kfcm = {} * kilocalorie_per_mole / angstrom**2\n".format(kf_cm),
        "tolcm = {} * angstrom \n".format(tol_cm),
        # Expression for the flat-bottom CM-CM restraint.
        'expr = "0.5 * kfcm * step(dist - tolcm) * (dist - tolcm)^2;""dist = sqrt((x1 - x2 - offx)^2 + (y1 - y2 - offy)^2 + (z1 - z2 - offz)^2);"\n',
        "force_CMCM = CustomCentroidBondForce(2, expr)\n",
        "force_CMCM.addPerBondParameter('kfcm')\n",
        "force_CMCM.addPerBondParameter('tolcm')\n",
        "force_CMCM.addPerBondParameter('offx')\n",
        "force_CMCM.addPerBondParameter('offy')\n",
        "force_CMCM.addPerBondParameter('offz')\n",
        # Bound ligand: no offset between the two centroids.
        "force_CMCM.addGroup(protein_com)\n",
        "force_CMCM.addGroup(lig1_com)\n",
        "parameters_bound = (\n"
        "    kfcm.value_in_unit(kilojoules_per_mole / nanometer**2),\n"
        "    tolcm.value_in_unit(nanometer),\n"
        "    0.0 * nanometer,\n"
        "    0.0 * nanometer,\n"
        "    0.0 * nanometer,\n"
        "    )\n",
        "force_CMCM.addBond((1,0), parameters_bound)\n",
        "numgroups = force_CMCM.getNumGroups()\n",
        # Free ligand: centroids are offset by the displacement vector.
        "force_CMCM.addGroup(protein_com)\n",
        "force_CMCM.addGroup(lig2_com)\n",
        "parameters_free = (\n"
        "    kfcm.value_in_unit(kilojoules_per_mole / nanometer**2),\n"
        "    tolcm.value_in_unit(nanometer),\n"
        "    displacement[0] * nanometer,\n"
        "    displacement[1] * nanometer,\n"
        "    displacement[2] * nanometer,\n"
        "    )\n",
        "force_CMCM.addBond((numgroups+1,numgroups+0), parameters_free)\n",
    ]
    if force_group is not None:
        chunks.append("force_CMCM.setForceGroup({})\n".format(force_group))
    chunks.append("system.addForce(force_CMCM)\n")
    chunks.append("#End of CM-CM force\n\n")

    return "".join(chunks)
force_group is not None: + if not isinstance(force_group, int): + raise TypeError("Force group must be an integer") + output += "force_CMCM.setForceGroup({})\n".format(force_group) + output += "system.addForce(force_CMCM)\n" + output += "#End of CM-CM force\n\n" + + return output + + def create_flat_bottom_restraint(self, restrained_atoms, force_group=None): + """Flat bottom restraint for atom-compatible position restraints + + Parameters + ---------- + + restrained_atoms : list + List of atom indices to be restrained. Need to be explicitly given + due to the ability to parse strings in the protocol. + + force_group : None or int + Group of the force to be added to the system. If none defined then no + force group will be set (therefore it will default to 0). Only tested + for single-point energies. + """ + # Still using the position restraint mixin, get the values of the relevant constants + pos_const = self.protocol.getForceConstant().value() + pos_width = self.protocol.getPosRestWidth().value() + output = "" + output += "fc = {} * kilocalorie_per_mole / angstrom**2\n".format(pos_const) + output += "tol = {} * angstrom\n".format(pos_width) + output += "restrained_atoms = {}\n".format(restrained_atoms) + output += "positions = prm.positions\n" + output += 'posrestforce = CustomExternalForce("0.5*fc*select(step(dist-tol), (dist-tol)^2, 0); dist = periodicdistance(x,y,z,x0,y0,z0)")\n' + + output += 'posrestforce.addPerParticleParameter("x0")\n' + output += 'posrestforce.addPerParticleParameter("y0")\n' + output += 'posrestforce.addPerParticleParameter("z0")\n' + output += 'posrestforce.addPerParticleParameter("fc")\n' + output += 'posrestforce.addPerParticleParameter("tol")\n' + + output += "for i in restrained_atoms:\n" + output += " x1 = positions[i][0].value_in_unit_system(openmm.unit.md_unit_system)\n" + output += " y1 = positions[i][1].value_in_unit_system(openmm.unit.md_unit_system)\n" + output += " z1 = 
positions[i][2].value_in_unit_system(openmm.unit.md_unit_system)\n" + output += " fc1 = fc.value_in_unit(kilojoules_per_mole / nanometer**2)\n" + output += " tol1 = tol.value_in_unit(nanometer)\n" + output += " posrestforce.addParticle(i, [x1, y1, z1, fc1, tol1])\n" + if force_group is not None: + if not isinstance(force_group, int): + raise TypeError("Force group must be an integer") + output += "posrestforce.setForceGroup({})\n".format(force_group) + output += "system.addForce(posrestforce)\n" + return output + + def createAnnealingProtocol(self): + """ + Create a string which can be added directly to an openmm script to add an annealing protocol to the system. + """ + anneal_runtime = self.protocol.getRunTime() + num_cycles = self.protocol.getAnnealNumCycles() + cycle_numsteps = int( + (anneal_runtime / num_cycles) / self.protocol.getTimeStep() + ) + + prot = self.protocol.getAnnealValues() + # Find all entries whose keys contain "start" and create a dictionary of these entries + # Also remove the word "start" from the key + start = {k.replace("_start", ""): v for k, v in prot.items() if "start" in k} + # Same for "end" + end = {k.replace("_end", ""): v for k, v in prot.items() if "end" in k} + # write protocol to output in dictionary format + output = "" + output += f"values_start = {start}\n" + output += f"values_end = {end}\n" + output += "increments = {\n" + output += f" key: (values_end[key] - values_start[key]) / {num_cycles}\n" + output += " for key in values_start.keys()\n" + output += "}\n" + # First set all values using the start values + output += "for key in values_start.keys():\n" + output += " simulation.context.setParameter(key, values_start[key])\n" + # Now perform the annealing in cycles + output += f"for i in range({int(num_cycles)}):\n" + output += f" simulation.step({cycle_numsteps})\n" + output += " print(f'Cycle {i+1}')\n" + output += " state = simulation.context.getState(getPositions=True, getVelocities=True)\n" + output += " for key in 
values_start.keys():\n" + output += " simulation.context.setParameter(key, simulation.context.getParameter(key) + increments[key])\n" + output += "simulation.saveState('openmm.xml')" + return output + + def createRestartLogic(self, total_cycles, steps_per_cycle): + # Creates the logic to calculate, at run time, the number of cycles that need to be run + # based on the number of steps that have already been run + output = "" + output += f"total_required_cycles = {total_cycles}\n" + output += "if not is_restart:\n" + output += " steps_so_far = 0\n" + output += " numcycles = total_required_cycles\n" + output += "else:\n" + output += " steps_so_far = step\n" + output += f" cycles_so_far = steps_so_far / {steps_per_cycle}\n" + output += " numcycles = int(total_required_cycles - cycles_so_far)\n" + return output + + def createLoopWithReporting( + self, + name, + steps_per_cycle, + report_interval, + timestep, + inflex_point, + ): + """Creates the loop in which simulations are run, stopping each cycle + to report the potential energies required for MBAR analysis. + + Parameters + ---------- + + cycles : int + Number of cycles to run the simulation for. + + steps_per_cycle : int + Number of steps to run the simulation for in each cycle. + + report_interval : int (in ps) + Interval at which to report the potential energies. + + timestep : float (in ps) + Timestep used in the simulation. + + steps : int + Total number of steps that have been performed so far (Default 0). + + inflex_point : int + The index at which the protocol changes direction. Potentials only need to be calculated for each half of the protocol. + """ + output = "" + output += "# Reporting for MBAR:\n" + # round master lambda to 4 d.p. 
to avoid floating point errors + output += f"master_lambda_list = {[round(i,4) for i in self.protocol._get_lambda_values()]}\n" + output += f"master_lambda = master_lambda_list[window_index]\n" + + output += "if is_restart:\n" + output += " try:\n" + output += " MBAR_df = pd.read_csv(f'energies_{master_lambda}.csv')\n" + output += " energies = MBAR_df.to_dict('list')\n" + output += " energies = {float(k) if k.replace('.', '').isdigit() else k: v for k, v in energies.items()}\n" + output += " except FileNotFoundError:\n" + output += " raise FileNotFoundError('MBAR data not found, unable to restart')\n" + output += "else:\n" + output += " energies = {}\n" + output += " energies['time'] = []\n" + output += " energies['fep-lambda'] = []\n" + output += " energies['temperature'] = []\n" + output += " for i in master_lambda_list:\n" + output += " energies[i] = []\n" + output += f"\n# Run the simulation in cycles, with each cycle having {report_interval} steps.\n" + output += "# Timestep in ps\n" + output += f"timestep = {timestep}\n" + output += f"inflex_point = {inflex_point}\n" + output += f"for x in range(0, numcycles):\n" + output += f" simulation.step({steps_per_cycle})\n" + output += f" steps_so_far += {steps_per_cycle}\n" + output += " time = steps_so_far * timestep\n" + output += " energies['time'].append(time)\n" + output += " energies['fep-lambda'].append(master_lambda)\n" + output += " energies['temperature'].append(integrator.getTemperature().value_in_unit(kelvin))\n" + output += " #now loop over all simulate lambda values, set the values in the context, and calculate potential energy\n" + output += " # do the first half of master lambda if direction == 1\n" + output += " if direction == 1:\n" + output += ( + " for ind, lam in enumerate(master_lambda_list[:inflex_point]):\n" + ) + output += " for key in atm_constants.keys():\n" + output += " if key in ['Alpha','Uh','W0']:\n" + output += " simulation.context.setParameter(key, 
atm_constants[key][ind].value_in_unit(kilojoules_per_mole))\n" + output += " else:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind])\n" + output += " state = simulation.context.getState(getEnergy=True)\n" + output += " energies[lam].append(state.getPotentialEnergy().value_in_unit(kilocalories_per_mole))\n" + output += " #fill the rest of the dictionary with NaNs\n" + output += " for lam in master_lambda_list[inflex_point:]:\n" + output += " energies[lam].append(float('nan'))\n" + output += " # do the second half of master lambda if direction == -1\n" + output += " else:\n" + output += " #fill the first half of the dictionary with NaNs\n" + output += " for lam in master_lambda_list[:inflex_point]:\n" + output += " energies[lam].append(float('nan'))\n" + output += ( + " for ind, lam in enumerate(master_lambda_list[inflex_point:]):\n" + ) + output += " for key in atm_constants.keys():\n" + output += " if key in ['Alpha','Uh','W0']:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind+inflex_point].value_in_unit(kilojoules_per_mole))\n" + output += " else:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind+inflex_point])\n" + output += " state = simulation.context.getState(getEnergy=True)\n" + output += " energies[lam].append(state.getPotentialEnergy().value_in_unit(kilocalories_per_mole))\n" + output += ( + " #Now reset lambda-dependent values back to their original state\n" + ) + output += " simulation.context.setParameter('Lambda1',lambda1)\n" + output += " simulation.context.setParameter('Lambda2',lambda2)\n" + output += " simulation.context.setParameter('Alpha',alpha)\n" + output += " simulation.context.setParameter('Uh',uh)\n" + output += " simulation.context.setParameter('W0',w0)\n" + output += " simulation.context.setParameter('Direction',direction)\n" + output += f" simulation.saveState('{name}.xml')\n" + output += " #now dump data to a csv file\n" + output += f" df = 
pd.DataFrame(energies)\n" + output += " df.set_index(['time', 'fep-lambda'], inplace=True)\n" + output += " df.to_csv(f'energies_{master_lambda}.csv')\n" + output += "#Dump final data to csv file\n" + output += "df = pd.DataFrame(energies)\n" + output += "df.set_index(['time', 'fep-lambda'], inplace=True)\n" + output += "df.to_csv(f'energies_{master_lambda}.csv')\n" + return output + + def createSoftcorePertELoop( + self, + name, + steps_per_cycle, + report_interval, + timestep, + ): + """Recreation of Gallicchio lab analysis - currently uses {cycles} to define sampling frequency""" + output = "" + output += f"\n# Run the simulation in cycles, with each cycle having {report_interval} steps.\n" + output += "# Timestep in ps\n" + output += f"timestep = {timestep}\n" + output += "\n" + output += "#Create dictionary for storing results in the same manner as the Gallicchio lab code\n" + # Logic for restarting simulations + output += "#Reporting for UWHAM:\n" + output += "if is_restart:\n" + # first UWHAM + output += " try:\n" + output += f" UWHAM_df = pd.read_csv('{name}.csv')\n" + output += " result = UWHAM_df.to_dict('list')\n" + output += " except FileNotFoundError:\n" + output += " raise FileNotFoundError('UWHAM data not found, unable to restart')\n" + output += "else:\n" + output += " result = {}\n" + output += " result['window'] = []\n" + output += " result['temperature'] = []\n" + output += " result['direction'] = []\n" + output += " result['lambda1'] = []\n" + output += " result['lambda2'] = []\n" + output += " result['alpha'] = []\n" + output += " result['uh'] = []\n" + output += " result['w0'] = []\n" + output += " result['pot_en'] = []\n" + output += " result['pert_en'] = []\n" + output += " result['metad_offset'] = []\n" + + output += f"for x in range(0, numcycles):\n" + output += f" simulation.step({steps_per_cycle})\n" + output += ( + " state = simulation.context.getState(getEnergy = True, groups = -1)\n" + ) + output += " pot_energy = 
state.getPotentialEnergy()\n" + output += " (u1, u0, alchemicalEBias) = atm_force.getPerturbationEnergy(simulation.context)\n" + output += " umcore = simulation.context.getParameter(atm_force.Umax())* kilojoules_per_mole\n" + output += " ubcore = simulation.context.getParameter(atm_force.Ubcore())* kilojoules_per_mole\n" + output += " acore = simulation.context.getParameter(atm_force.Acore())\n" + output += " uoffset = 0.0 * kilojoules_per_mole\n" + output += ( + " direction = simulation.context.getParameter(atm_force.Direction())\n" + ) + output += " if direction > 0:\n" + output += ( + " pert_e = softCorePertE(u1-(u0+uoffset), umcore, ubcore, acore)\n" + ) + output += " else:\n" + output += ( + " pert_e = softCorePertE(u0-(u1+uoffset), umcore, ubcore, acore)\n" + ) + output += " result['window'].append(window_index)\n" + output += " result['temperature'].append(integrator.getTemperature().value_in_unit(kelvin))\n" + output += " result['direction'].append(direction)\n" + output += " result['lambda1'].append(lambda1)\n" + output += " result['lambda2'].append(lambda2)\n" + output += ( + " result['alpha'].append(alpha.value_in_unit(kilocalories_per_mole))\n" + ) + output += " result['uh'].append(uh.value_in_unit(kilocalories_per_mole))\n" + output += " result['w0'].append(w0.value_in_unit(kilocalories_per_mole))\n" + output += " result['pot_en'].append(pot_energy.value_in_unit(kilocalories_per_mole))\n" + output += " result['pert_en'].append(pert_e.value_in_unit(kilocalories_per_mole))\n" + output += " result['metad_offset'].append(0.0)\n" + output += " #save the state of the simulation\n" + output += f" simulation.saveState('{name}.xml')\n" + output += " #now dump data to a csv file\n" + output += " df = pd.DataFrame(result)\n" + output += " df.set_index('window', inplace=True)\n" + output += f" df.to_csv(f'{name}.csv')\n" + + output += "#now convert the final dictionary to a pandas dataframe\n" + output += "df = pd.DataFrame(result)\n" + output += 
"df.set_index('window', inplace= True)\n" + output += f"df.to_csv('{name}.csv')\n" + return output + + def createReportingBoth( + self, + name, + steps_per_cycle, + timestep, + inflex_point, + ): + output = "" + output += "# Timestep in ps\n" + output += f"timestep = {timestep}\n" + output += "\n" + # Logic for restarting simulations + output += "#Reporting for UWHAM:\n" + output += "if is_restart:\n" + # first UWHAM + output += " try:\n" + output += f" UWHAM_df = pd.read_csv('{name}.csv')\n" + output += " result = UWHAM_df.to_dict('list')\n" + output += " except FileNotFoundError:\n" + output += " raise FileNotFoundError('UWHAM data not found, unable to restart')\n" + output += "else:\n" + output += " result = {}\n" + output += " result['window'] = []\n" + output += " result['temperature'] = []\n" + output += " result['direction'] = []\n" + output += " result['lambda1'] = []\n" + output += " result['lambda2'] = []\n" + output += " result['alpha'] = []\n" + output += " result['uh'] = []\n" + output += " result['w0'] = []\n" + output += " result['pot_en'] = []\n" + output += " result['pert_en'] = []\n" + output += " result['metad_offset'] = []\n" + + output += "# Reporting for MBAR:\n" + # round master lambda to 4 d.p. 
to avoid floating point errors + output += f"master_lambda_list = {[round(i,4) for i in self.protocol._get_lambda_values()]}\n" + output += f"master_lambda = master_lambda_list[window_index]\n" + output += "if is_restart:\n" + output += " try:\n" + output += " MBAR_df = pd.read_csv(f'energies_{master_lambda}.csv')\n" + output += " energies = MBAR_df.to_dict('list')\n" + output += " energies = {float(k) if k.replace('.', '').isdigit() else k: v for k, v in energies.items()}\n" + output += " except FileNotFoundError:\n" + output += " raise FileNotFoundError('MBAR data not found, unable to restart')\n" + output += "else:\n" + output += " energies = {}\n" + output += " energies['time'] = []\n" + output += " energies['fep-lambda'] = []\n" + output += " energies['temperature'] = []\n" + output += " for i in master_lambda_list:\n" + output += " energies[i] = []\n" + + output += f"inflex_point = {inflex_point}\n" + + output += "# Now run the simulation.\n" + output += f"for x in range(0, numcycles):\n" + output += f" simulation.step({steps_per_cycle})\n" + output += f" steps_so_far += {steps_per_cycle}\n" + output += " time = steps_so_far * timestep\n" + output += ( + " state = simulation.context.getState(getEnergy = True, groups = -1)\n" + ) + output += " pot_energy = state.getPotentialEnergy()\n" + output += " (u1, u0, alchemicalEBias) = atm_force.getPerturbationEnergy(simulation.context)\n" + output += " umcore = simulation.context.getParameter(atm_force.Umax())* kilojoules_per_mole\n" + output += " ubcore = simulation.context.getParameter(atm_force.Ubcore())* kilojoules_per_mole\n" + output += " acore = simulation.context.getParameter(atm_force.Acore())\n" + output += " uoffset = 0.0 * kilojoules_per_mole\n" + output += ( + " direction = simulation.context.getParameter(atm_force.Direction())\n" + ) + output += " if direction > 0:\n" + output += ( + " pert_e = softCorePertE(u1-(u0+uoffset), umcore, ubcore, acore)\n" + ) + output += " else:\n" + output += ( + " pert_e = 
softCorePertE(u0-(u1+uoffset), umcore, ubcore, acore)\n" + ) + output += " result['window'].append(window_index)\n" + output += " result['temperature'].append(integrator.getTemperature().value_in_unit(kelvin))\n" + output += " result['direction'].append(direction)\n" + output += " result['lambda1'].append(lambda1)\n" + output += " result['lambda2'].append(lambda2)\n" + output += ( + " result['alpha'].append(alpha.value_in_unit(kilocalories_per_mole))\n" + ) + output += " result['uh'].append(uh.value_in_unit(kilocalories_per_mole))\n" + output += " result['w0'].append(w0.value_in_unit(kilocalories_per_mole))\n" + output += " result['pot_en'].append(pot_energy.value_in_unit(kilocalories_per_mole))\n" + output += " result['pert_en'].append(pert_e.value_in_unit(kilocalories_per_mole))\n" + output += " result['metad_offset'].append(0.0)\n" + output += " energies['time'].append(time)\n" + output += " energies['fep-lambda'].append(master_lambda)\n" + output += " energies['temperature'].append(integrator.getTemperature().value_in_unit(kelvin))\n" + output += " #now loop over all simulate lambda values, set the values in the context, and calculate potential energy\n" + output += " # do the first half of master lambda if direction == 1\n" + output += " if direction == 1:\n" + output += ( + " for ind, lam in enumerate(master_lambda_list[:inflex_point]):\n" + ) + output += " for key in atm_constants.keys():\n" + output += " if key in ['Alpha','Uh','W0']:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind].value_in_unit(kilojoules_per_mole))\n" + output += " else:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind])\n" + output += " state = simulation.context.getState(getEnergy=True)\n" + output += " energies[lam].append(state.getPotentialEnergy().value_in_unit(kilocalories_per_mole))\n" + output += " #fill the rest of the dictionary with NaNs\n" + output += " for lam in master_lambda_list[inflex_point:]:\n" + output += " 
energies[lam].append(float('nan'))\n" + output += " # do the second half of master lambda if direction == -1\n" + output += " else:\n" + output += " #fill the first half of the dictionary with NaNs\n" + output += " for lam in master_lambda_list[:inflex_point]:\n" + output += " energies[lam].append(float('nan'))\n" + output += ( + " for ind, lam in enumerate(master_lambda_list[inflex_point:]):\n" + ) + output += " for key in atm_constants.keys():\n" + output += " if key in ['Alpha','Uh','W0']:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind+inflex_point].value_in_unit(kilojoules_per_mole))\n" + output += " else:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind+inflex_point])\n" + output += " state = simulation.context.getState(getEnergy=True)\n" + output += " energies[lam].append(state.getPotentialEnergy().value_in_unit(kilocalories_per_mole))\n" + output += ( + " #Now reset lambda-dependent values back to their original state\n" + ) + output += " simulation.context.setParameter('Lambda1',lambda1)\n" + output += " simulation.context.setParameter('Lambda2',lambda2)\n" + output += " simulation.context.setParameter('Alpha',alpha)\n" + output += " simulation.context.setParameter('Uh',uh)\n" + output += " simulation.context.setParameter('W0',w0)\n" + output += " simulation.context.setParameter('Direction',direction)\n" + output += " #save the state of the simulation\n" + output += f" simulation.saveState('{name}.xml')\n" + output += " #now dump UWHAM data to a csv file\n" + output += " df = pd.DataFrame(result)\n" + output += " df.set_index('window', inplace=True)\n" + output += f" df.to_csv(f'{name}.csv')\n" + output += " #now dump MBAR data to a csv file\n" + output += " df = pd.DataFrame(energies)\n" + output += " df.set_index(['time', 'fep-lambda'], inplace=True)\n" + output += " df.to_csv(f'energies_{master_lambda}.csv')\n" + + output += "#now convert the UWHAM dictionary to a pandas dataframe\n" + output += 
"df = pd.DataFrame(result)\n" + output += "df.set_index('window', inplace= True)\n" + output += f"df.to_csv('{name}.csv')\n" + + output += "# same for MBAR\n" + output += "df = pd.DataFrame(energies)\n" + output += "df.set_index(['time', 'fep-lambda'], inplace=True)\n" + output += "df.to_csv(f'energies_{master_lambda}.csv')\n" + + return output + + def createSinglePointTest( + self, + inflex_point, + name, + atm_force_group=None, + position_restraint_force_group=None, + alignment_force_groups=None, + com_force_group=None, + ): + """Create a single point test for the ATM force""" + output = "" + output += "# Create the dictionary which will hold the energies\n" + output += f"master_lambda_list = {[round(i,4) for i in self.protocol._get_lambda_values()]}\n" + output += "energies = {}\n" + output += f"for i in master_lambda_list[:{inflex_point}]:\n" + output += " energies[i] = []\n" + # First we can check the potential of forces that are not lambda-dependent, + # this will only work if the ATMforce is in its own group + + if ( + (position_restraint_force_group is not None) + and (alignment_force_groups is not None) + and (com_force_group is not None) + ): + output += "non_lambda_forces = {}\n" + output += f"pos_state = simulation.context.getState(getEnergy=True, groups={{{position_restraint_force_group}}})\n" + output += "non_lambda_forces['position_restraint'] = pos_state.getPotentialEnergy().value_in_unit(kilojoules_per_mole)\n" + output += f"alignment_force_groups = {alignment_force_groups}\n" + output += "for counter,group in enumerate(alignment_force_groups):\n" + output += " if counter == 0:\n" + output += " name='distance'\n" + output += " elif counter == 1:\n" + output += " name='angle'\n" + output += " elif counter == 2:\n" + output += " name='dihedral'\n" + output += " align_state = simulation.context.getState(getEnergy=True, groups={group})\n" + output += " non_lambda_forces[name] = align_state.getPotentialEnergy().value_in_unit(kilojoules_per_mole)\n" + 
output += f"com_state = simulation.context.getState(getEnergy=True, groups={{{com_force_group}}})\n" + output += "non_lambda_forces['com'] = com_state.getPotentialEnergy().value_in_unit(kilojoules_per_mole)\n" + # now save as a dataframe + output += "df = pd.DataFrame(non_lambda_forces,index=[0])\n" + output += "df.to_csv(f'non_lambda_forces.csv')\n" + + output += "#now loop over all simulate lambda values, set the values in the context, and calculate potential energy\n" + output += f"for ind, lam in enumerate(master_lambda_list[:{inflex_point}]):\n" + output += " for key in atm_constants.keys():\n" + output += " if key in ['Alpha','Uh','W0']:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind].value_in_unit(kilojoules_per_mole))\n" + output += " else:\n" + output += " simulation.context.setParameter(key, atm_constants[key][ind])\n" + if atm_force_group is None: + output += " state = simulation.context.getState(getEnergy=True)\n" + else: + group_placeholder = f"{atm_force_group}" + output += f" state = simulation.context.getState(getEnergy=True, groups={{{group_placeholder}}})\n" + output += " energies[lam].append(state.getPotentialEnergy().value_in_unit(kilojoules_per_mole))\n" + output += ( + " #Now reset lambda-dependent values back to their original state\n" + ) + output += " simulation.context.setParameter('Lambda1',lambda1)\n" + output += " simulation.context.setParameter('Lambda2',lambda2)\n" + output += " simulation.context.setParameter('Alpha',alpha)\n" + output += " simulation.context.setParameter('Uh',uh)\n" + output += " simulation.context.setParameter('W0',w0)\n" + output += " simulation.context.setParameter('Direction',direction)\n" + output += "#now convert the dictionary to a pandas dataframe, with both time and fep-lambda as index columns\n" + output += "df = pd.DataFrame(energies)\n" + output += "df.to_csv(f'energies_singlepoint.csv')\n" + output += "simulation.step(1)\n" + output += 
f"simulation.saveState('{name}.xml')\n" + return output diff --git a/python/BioSimSpace/Process/_openmm.py b/python/BioSimSpace/Process/_openmm.py index b412cb275..327c8e2cb 100644 --- a/python/BioSimSpace/Process/_openmm.py +++ b/python/BioSimSpace/Process/_openmm.py @@ -49,7 +49,6 @@ from ..Metadynamics import CollectiveVariable as _CollectiveVariable from ..Protocol._position_restraint_mixin import _PositionRestraintMixin from ..Types._type import Type as _Type - from .. import IO as _IO from .. import Protocol as _Protocol from .. import Trajectory as _Trajectory @@ -68,6 +67,29 @@ class OpenMM(_process.Process): # Dictionary of platforms and their OpenMM keyword. _platforms = {"CPU": "CPU", "CUDA": "CUDA", "OPENCL": "OpenCL"} + # Special cases for generate config when using ATM protocols. + def __new__( + cls, + system=None, + protocol=None, + reference_system=None, + exe=None, + name="openmm", + platform="CPU", + work_dir=None, + seed=None, + property_map={}, + **kwargs, + ): + from ._atm import OpenMMATM + from ..Protocol._atm import _ATM + + # would like to use issubclass but _Protocol._ATM is not exposed + if isinstance(protocol, _ATM): + return super().__new__(OpenMMATM) + else: + return super().__new__(cls) + def __init__( self, system, @@ -1319,7 +1341,9 @@ def getSystem(self, block="AUTO"): # Try to get the most recent trajectory frame. try: # Handle minimisation protocols separately. - if isinstance(self._protocol, _Protocol.Minimisation): + if isinstance( + self._protocol, (_Protocol.Minimisation, _Protocol.ATMMinimisation) + ): # Do we need to get coordinates for the lambda=1 state. if "is_lambda1" in self._property_map: is_lambda1 = True @@ -1372,22 +1396,18 @@ def getSystem(self, block="AUTO"): (self._protocol.getRunTime() / self._protocol.getTimeStep()) / self._protocol.getRestartInterval() ) - # Work out the fraction of the simulation that has been completed. 
frac_complete = self._protocol.getRunTime() / self.getTime() - # Make sure the fraction doesn't exceed one. OpenMM can report # time values that are larger than the number of integration steps # multiplied by the time step. if frac_complete > 1: frac_complete = 1 - # Work out the trajectory frame index, rounding down. # Remember that frames in MDTraj are zero indexed, like Python. index = int(frac_complete * num_frames) if index > 0: index -= 1 - # Return the most recent frame. return self.getFrame(index) @@ -1468,7 +1488,6 @@ def getFrame(self, index): if not type(index) is int: raise TypeError("'index' must be of type 'int'") - max_index = ( int( (self._protocol.getRunTime() / self._protocol.getTimeStep()) @@ -1476,7 +1495,6 @@ def getFrame(self, index): ) - 1 ) - if index < 0 or index > max_index: raise ValueError(f"'index' must be in range [0, {max_index}].") @@ -1496,7 +1514,6 @@ def getFrame(self, index): # Get the latest trajectory frame. new_system = _Trajectory.getFrame(self._traj_file, self._top_file, index) - # Update the coordinates and velocities and return a mapping between # the molecule indices in the two systems. sire_system, mapping = _SireIO.updateCoordinatesAndVelocities( @@ -2120,7 +2137,9 @@ def _add_config_reporters( ) # Disable specific state information for minimisation protocols. - if isinstance(self._protocol, _Protocol.Minimisation): + if isinstance( + self._protocol, (_Protocol.Minimisation, _Protocol.ATMMinimisation) + ): is_step = False is_time = False is_temperature = False @@ -2130,7 +2149,9 @@ def _add_config_reporters( is_temperature = True # Work out the total number of steps. 
- if isinstance(self._protocol, _Protocol.Minimisation): + if isinstance( + self._protocol, (_Protocol.Minimisation, _Protocol.ATMMinimisation) + ): total_steps = 1 else: total_steps = _math.ceil( diff --git a/python/BioSimSpace/Protocol/__init__.py b/python/BioSimSpace/Protocol/__init__.py index aa4a63322..e78c9f146 100644 --- a/python/BioSimSpace/Protocol/__init__.py +++ b/python/BioSimSpace/Protocol/__init__.py @@ -47,6 +47,10 @@ Metadynamics Steering Custom + ATMMinimisation + ATMEquilibration + ATMAnnealing + ATMProduction Examples ======== @@ -105,3 +109,4 @@ from ._production import * from ._steering import * from ._utils import * +from ._atm import * diff --git a/python/BioSimSpace/Protocol/_atm.py b/python/BioSimSpace/Protocol/_atm.py new file mode 100644 index 000000000..d231492eb --- /dev/null +++ b/python/BioSimSpace/Protocol/_atm.py @@ -0,0 +1,3353 @@ +###################################################################### +# BioSimSpace: Making biomolecular simulation a breeze! +# +# Copyright: 2017-2024 +# +# Authors: Lester Hedges +# Matthew Burman +# +# BioSimSpace is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# BioSimSpace is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with BioSimSpace. If not, see . +##################################################################### + +import json as _json +import math as _math +import numpy as _np +import warnings as _warnings + +from .._SireWrappers import System as _System +from .. 
import Types as _Types +from ._protocol import Protocol as _Protocol +from ._position_restraint_mixin import _PositionRestraintMixin +from .. import Units as _Units +from ..Types import Vector as _Vector + +__all__ = ["ATMMinimisation", "ATMEquilibration", "ATMAnnealing", "ATMProduction"] + + +# When placed in to BSS this needs to be ATM_protocol(protocol): +class _ATM(_Protocol, _PositionRestraintMixin): + def __init__( + self, + system=None, + data=None, + core_alignment=True, + align_k_distance=2.5 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + align_k_theta=10.0 * _Units.Energy.kcal_per_mol, + align_k_psi=10.0 * _Units.Energy.kcal_per_mol, + com_distance_restraint=True, + com_k=25.0 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + com_restraint_width=5.0 * _Units.Length.angstrom, + restraint=None, + force_constant=10 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + positional_restraint_width=0.5 * _Units.Length.angstrom, + soft_core_umax=1000.0 * _Units.Energy.kcal_per_mol, + soft_core_u0=500.0 * _Units.Energy.kcal_per_mol, + soft_core_a=0.0625, + ): + # Call the base class constructor. + super().__init__() + + # first check that EITHER system or data is passed + if system is None and data is None: + raise ValueError( + "Either 'system' or 'data' must be passed to the ATM protocol." + ) + + if system is not None and not isinstance(system, _System): + raise TypeError("'system' must be of type 'BioSimSpace.System'") + + if data is not None and not isinstance(data, dict): + raise TypeError("'data' must be of type 'dict'") + + if isinstance(system, _System) and data is None: + try: + sdata = _json.loads(system._sire_object.property("atom_data").value()) + except Exception as e: + raise ValueError( + f"Unable to extract ATM data from the system object. The following error was raised: {e}." 
+ ) + # convert the "displacement" key back to a vector + d = sdata["displacement"] + displacement = _Vector(*d) + sdata["displacement"] = displacement + self._system_data = sdata + + elif system is not None and data is not None: + _warnings.warn( + "Both 'system' and 'data' were passed. Using 'data' and ignoring data from 'system'." + ) + + # Store the ATM system. + if isinstance(data, dict): + self._system_data = data + elif data is not None: + raise TypeError("'data' must be of type 'dict'") + + # Whether or not to use alignment restraints. + self.setCoreAlignment(core_alignment) + + # Store the align_k_distance value. + self.setAlignKDistance(align_k_distance) + + # Store the align_k_theta value. + self.setAlignKTheta(align_k_theta) + + # Store the align_k_psi value. + self.setAlignKPsi(align_k_psi) + + # Whether or not to use the CMCM restraint. + self.setCOMDistanceRestraint(com_distance_restraint) + + # Store com_k value. + self.setCOMk(com_k) + + # Store com_restraint_width value. + self.setCOMWidth(com_restraint_width) + + # Store the width of the coordinate restraint. + self.setPosRestWidth(positional_restraint_width) + + # Store the soft_core_umax value. + self.setSoftCoreUmax(soft_core_umax) + + # Store the soft_core_u0 value. + self.setSoftCoreU0(soft_core_u0) + + # Store the soft_core_a value. + self.setSoftCoreA(soft_core_a) + + # Set the postition restraint. 
+ _PositionRestraintMixin.__init__(self, restraint, force_constant) + + def __str__(self): + d = self.getData() + """Return a string representation of the protocol.""" + string = ": " + string += "timestep=%s " % self.getTimeStep() + string += ", runtime=%s " % self.getRunTime() + string += ", temperature=%s " % self.getTemperature() + if self._pressure is not None: + string += ", pressure=%s, " % self.getPressure() + string += ", lambda1=%s " % self.getLambda1() + string += ", lambda2=%s " % self.getLambda2() + string += ", ligand_bound core atoms=%s" % d["ligand_bound_rigid_core"] + string += ", ligand_free core atoms=%s" % d["ligand_free_rigid_core"] + string += ", report_interval=%s " % self.getReportInterval() + string += ", restart_interval=%s " % self.getRestartInterval() + string += ">" + + return string + + def __repr__(self): + """Return a string showing how to instantiate the object.""" + return self.__str__() + + def getData(self): + """ + Return the ATM data dictionary. + + Returns + ------- + + data : dict + The ATM data dictionary. + """ + return self._system_data + + def getCoreAlignment(self): + """ + Return core alignment boolean. + + Returns + ------- + + core_alignment : bool + Whether to use core alignment. + """ + return self._core_alignment + + def setCoreAlignment(self, core_alignment): + """ + Set the core alignment flag. + + Parameters + ---------- + + core_alignment : bool + Whether to use core alignment. + """ + if isinstance(core_alignment, bool): + self._core_alignment = core_alignment + else: + _warnings.warn("Non-boolean core alignment flag. Defaulting to True!") + self._core_alignment = True + + def getCOMDistanceRestraint(self): + """ + Return CMCM restraint boolean. + + Returns + ------- + + com_distance_restraint : bool + Whether to use the CMCM restraint. + """ + return self._com_distance_restraint + + def setCOMDistanceRestraint(self, com_distance_restraint): + """ + Set the CMCM restraint flag. 
+ + Parameters + ---------- + + com_distance_restraint : bool + Whether to use the CMCM restraint. + """ + if isinstance(com_distance_restraint, bool): + self._com_distance_restraint = com_distance_restraint + else: + _warnings.warn( + "Non-boolean com distance restraint flag. Defaulting to True!" + ) + self._com_distance_restraint = True + + def getPosRestWidth(self): + """ + Return the width of the position restraint. + + Returns + ------- + + positional_restraint_width : :class:`Length ` + The width of the position restraint. + """ + return self._positional_restraint_width + + def setPosRestWidth(self, positional_restraint_width): + """ + Set the width of the position restraint. + + Parameters + ---------- + + positional_restraint_width : int, float, str, :class:`Length ` + The width of the position restraint. + """ + # Convert int to float. + if type(positional_restraint_width) is int: + positional_restraint_width = float(positional_restraint_width) + + if isinstance(positional_restraint_width, float): + # Use default units. + positional_restraint_width *= _Units.Length.angstrom + + else: + if isinstance(positional_restraint_width, str): + try: + positional_restraint_width = _Types.Length( + positional_restraint_width + ) + except Exception: + raise ValueError( + "Unable to parse 'positional_restraint_width' string." + ) from None + + elif not isinstance(positional_restraint_width, _Types.Length): + raise TypeError( + "'positional_restraint_width' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if positional_restraint_width.dimensions() != (0, 1, 0, 0, 0, 0, 0): + raise ValueError( + "'positional_restraint_width' has invalid dimensions! " + f"Expected dimensions of Length, found '{positional_restraint_width.unit()}'" + ) + self._positional_restraint_width = positional_restraint_width + + def getAlignKDistance(self): + """ + Return the align_k_distance value. 
+ + Returns + ------- + + align_k_distance : :class:`GeneralUnit ` + The align_k_distance value in kcal/mol angstrom**2. + """ + return self._align_k_distance + + def setAlignKDistance(self, align_k_distance): + """ + Set the align_k_distance value. + + Parameters + ---------- + + align_k_distance : int, float, str, :class:`GeneralUnit `, float + Length value for the alignment restraint kcal/mol angstrom**2. + """ + # Convert int to float. + if type(align_k_distance) is int: + align_k_distance = float(align_k_distance) + + if isinstance(align_k_distance, float): + # Use default units. + align_k_distance *= _Units.Energy.kcal_per_mol / _Units.Area.angstrom2 + + else: + if isinstance(align_k_distance, str): + try: + align_k_distance = _Types._GeneralUnit(align_k_distance) + except Exception: + raise ValueError( + "Unable to parse 'align_k_distance' string." + ) from None + + elif not isinstance(align_k_distance, _Types._GeneralUnit): + raise TypeError( + "'align_k_distance' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if align_k_distance.dimensions() != (1, 0, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_distance' has invalid dimensions! " + f"Expected dimensions of energy density/area (e.g. kcal/molA^2), found '{align_k_distance.unit()}'" + ) + self._align_k_distance = align_k_distance + + def getAlignKTheta(self): + """ + Return the align_k_theta value. + + Returns + ------- + + align_k_theta : :class:`Energy ` + The align_k_theta value in kcal/mol. + """ + return self._align_k_theta + + def setAlignKTheta(self, align_k_theta): + """ + Set the align_k_theta value. + + Parameters + ---------- + + align_k_theta : int, float, str, :class:`Energy ` + Force constant for the alignment angular constraint in kcal/mol. + + """ + # Convert int to float. + if type(align_k_theta) is int: + align_k_theta = float(align_k_theta) + + if isinstance(align_k_theta, float): + # Use default units. 
+ align_k_theta *= _Units.Energy.kcal_per_mol + + else: + if isinstance(align_k_theta, str): + try: + align_k_theta = _Types._GeneralUnit(align_k_theta) + except Exception: + raise ValueError( + "Unable to parse 'align_k_theta' string." + ) from None + + elif not isinstance(align_k_theta, _Types.Energy): + raise TypeError( + "'align_k_theta' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if align_k_theta.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{align_k_theta.unit()}'" + ) + self._align_k_theta = align_k_theta + + def getAlignKPsi(self): + """ + Return the align_k_psi value. + + Returns + ------- + + align_k_psi: :class:`Energy ` + The align_k_psi value in kcal/mol. + """ + return self._align_k_psi + + def setAlignKPsi(self, align_k_psi): + """ + Set the align_k_psi value. + + Parameters + ---------- + + align_k_psi : int, float, str, :class:`Energy ` + Force constant for the alignment dihedral constraint in kcal/mol. + """ + # Convert int to float. + if type(align_k_psi) is int: + align_k_psi = float(align_k_psi) + + if isinstance(align_k_psi, float): + # Use default units. + align_k_psi *= _Units.Energy.kcal_per_mol + + else: + if isinstance(align_k_psi, str): + try: + align_k_psi = _Types._GeneralUnit(align_k_psi) + except Exception: + raise ValueError("Unable to parse 'align_k_psi' string.") from None + + elif not isinstance(align_k_psi, _Types.Energy): + raise TypeError( + "'align_k_psi' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if align_k_psi.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_psi' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. 
kcal/mol), found '{align_k_psi.unit()}'" + ) + self._align_k_psi = align_k_psi + + def getSoftCoreUmax(self): + """ + Return the soft_core_umax value. + + Returns + ------- + + soft_core_umax : :class:`Energy ` + The soft_core_umax value in kcal/mol. + """ + return self._soft_core_umax + + def setSoftCoreUmax(self, soft_core_umax): + """ + Set the soft_core_umax value. + + Parameters + ---------- + + soft_core_umax : int, float, str, :class:`Energy ` + The softcore Umax value in kcal/mol. + """ + # Convert int to float. + if type(soft_core_umax) is int: + soft_core_umax = float(soft_core_umax) + + if isinstance(soft_core_umax, float): + # Use default units. + soft_core_umax *= _Units.Energy.kcal_per_mol + + else: + if isinstance(soft_core_umax, str): + try: + soft_core_umax = _Types._GeneralUnit(soft_core_umax) + except Exception: + raise ValueError( + "Unable to parse 'soft_core_umax' string." + ) from None + + elif not isinstance(soft_core_umax, _Types.Energy): + raise TypeError( + "'soft_core_umax' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if soft_core_umax.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{soft_core_umax.unit()}'" + ) + self._soft_core_umax = soft_core_umax + + def getSoftCoreU0(self): + """ + Return the soft_core_u0 value. + + Returns + ------- + + soft_core_u0 : :class:`Energy ` + The soft_core_u0 value in kcal/mol. + """ + return self._soft_core_u0 + + def setSoftCoreU0(self, soft_core_u0): + """ + Set the soft_core_u0 value. + + Parameters + ---------- + + soft_core_u0 : int, float, str, :class:`Energy ` + The softcore u0 value in kcal/mol. + """ + # Convert int to float. + if type(soft_core_u0) is int: + soft_core_u0 = float(soft_core_u0) + + if isinstance(soft_core_u0, float): + # Use default units. 
+ soft_core_u0 *= _Units.Energy.kcal_per_mol + + else: + if isinstance(soft_core_u0, str): + try: + soft_core_u0 = _Types._GeneralUnit(soft_core_u0) + except Exception: + raise ValueError("Unable to parse 'soft_core_u0' string.") from None + + elif not isinstance(soft_core_u0, _Types.Energy): + raise TypeError( + "'soft_core_u0' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if soft_core_u0.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{soft_core_u0.unit()}'" + ) + self._soft_core_u0 = soft_core_u0 + + def getSoftCoreA(self): + """ + Return the soft_core_a value. + + Returns + ------- + + soft_core_a : float + The soft_core_a value. + """ + return self._soft_core_a + + def setSoftCoreA(self, soft_core_a): + """ + Set the soft_core_a value. + + Parameters + ---------- + + soft_core_a : float + The softcore a value. + """ + if isinstance(soft_core_a, (int, float)): + self._soft_core_a = float(soft_core_a) + else: + raise TypeError("'soft_core_a' must be of type 'float'") + + def getCOMk(self): + """ + Return the com_k value. + + Returns + ------- + + com_k : :class:`GeneralUnit ` + The com_k value in kcal/mol A**2. + """ + return self._com_k + + def setCOMk(self, com_k): + """ + Set the com_k value. + + Parameters + ---------- + + com_k : int, float, str, :class:`GeneralUnit + The force constant for the CM-CM force in kcal/mol A**2. + """ + # Convert int to float. + if type(com_k) is int: + com_k = float(com_k) + + if isinstance(com_k, float): + # Use default units. 
+ com_k *= _Units.Energy.kcal_per_mol / _Units.Area.angstrom2 + + else: + if isinstance(com_k, str): + try: + com_k = _Types._GeneralUnit(com_k) + except Exception: + raise ValueError("Unable to parse 'com_k' string.") from None + + elif not isinstance(com_k, _Types._GeneralUnit): + raise TypeError( + "'com_k' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if com_k.dimensions() != (1, 0, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density/area (e.g. kcal/molA^2), found '{com_k.unit()}'" + ) + self._com_k = com_k + + def getCOMWidth(self): + """ + Return the com_restraint_width value. + + Returns + ------- + + com_restraint_width : :class:`Length ` + The com_restraint_width value in angstroms. + """ + return self._com_restraint_width + + def setCOMWidth(self, com_restraint_width): + """ + Set the com_restraint_width value. + + Parameters + ---------- + + com_restraint_width : int, float, str, :class:`Length + The com_restraint_width value in angstroms. + """ + # Convert int to float. + if type(com_restraint_width) is int: + com_restraint_width = float(com_restraint_width) + + if isinstance(com_restraint_width, float): + # Use default units. + com_restraint_width *= _Units.Length.angstrom + + else: + if isinstance(com_restraint_width, str): + try: + com_restraint_width = _Types.Length(com_restraint_width) + except Exception: + raise ValueError( + "Unable to parse 'com_restraint_width' string." + ) from None + + elif not isinstance(com_restraint_width, _Types.Length): + raise TypeError( + "'com_restraint_width' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if com_restraint_width.dimensions() != (0, 1, 0, 0, 0, 0, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! 
" + f"Expected dimensions of Length, found '{com_restraint_width.unit()}'" + ) + self._com_restraint_width = com_restraint_width + + +class ATMMinimisation(_ATM): + """ + Minimisation protocol for ATM simulations. + """ + + def __init__( + self, + system=None, + data=None, + steps=10000, + core_alignment=True, + com_distance_restraint=True, + restraint=None, + force_constant=10 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + positional_restraint_width=0.5 * _Units.Length.angstrom, + align_k_distance=2.5 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + align_k_theta=10 * _Units.Energy.kcal_per_mol, + align_k_psi=10 * _Units.Energy.kcal_per_mol, + soft_core_umax=1000 * _Units.Energy.kcal_per_mol, + soft_core_u0=500 * _Units.Energy.kcal_per_mol, + soft_core_a=0.0625, + com_k=25 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + com_restraint_width=5 * _Units.Length.angstrom, + ): + """ + Parameters + ---------- + + system : :class:`System ` + A prepared ATM system. + data : dict + The ATM data dictionary. + + core_alignment : bool + Whether to use rigid core restraints to align the two ligands. + + align_k_distance : int, float, str, :class:`GeneralUnit ` + The force constant for the distance portion of the alignment restraint (kcal/(mol A^2)). + + align_k_theta : int, float, str, :class:`Energy ` + The force constant for the angular portion of the alignment restaint (kcal/mol). + + align_k_psi : int, float, str, :class:`Energy ` + The force constant for the dihedral portion of the alignment restraint (kcal/mol). + + com_distance_restraint : bool + Whether to use a center of mass distance restraint. + This restraint applies to the protein/host and both ligands, and + is used to maintain the relative positions of all of them. + + com_k : int, float, str, :class:`GeneralUnit ` + The force constant for the center of mass distance restraint (kcal/mol/A^2). 
+ + com_restraint_width : int, float, str, :class:`Length + The width (tolerance) of the center of mass distance restraint (A). + + restraint : str, [int] + The type of restraint to perform. This should be one of the + following options: + "backbone" + Protein backbone atoms. The matching is done by a name + template, so is unreliable on conversion between + molecular file formats. + "heavy" + All non-hydrogen atoms that aren't part of water + molecules or free ions. + "all" + All atoms that aren't part of water molecules or free + ions. + Alternatively, the user can pass a list of atom indices for + more fine-grained control. If None, then no restraints are used. + + force_constant : :class:`GeneralUnit `, float + The force constant for the restraint potential. If a 'float' is + passed, then default units of 'kcal_per_mol / angstrom**2' will + be used. + + positional_restraint_width : :class:`Length `, float + The width of the flat-bottom potential used for coordinate restraint in Angstroms. + + pos_restrained_atoms : [int] + The atoms to be restrained. + + + + soft_core_umax : int, float, str, :class:`Energy ` + The Umax value for the ATM softcore potential (kcal/mol). + + soft_core_u0 : int, float, str, :class:`Energy ` + The uh value for the ATM softcore potential (kcal/mol). + + soft_core_a : int, float, str, :class:`Energy ` + The a value for the ATM softcore potential.""" + + super().__init__( + system=system, + data=data, + core_alignment=core_alignment, + align_k_distance=align_k_distance, + align_k_theta=align_k_theta, + align_k_psi=align_k_psi, + com_distance_restraint=com_distance_restraint, + com_k=com_k, + com_restraint_width=com_restraint_width, + restraint=restraint, + force_constant=force_constant, + positional_restraint_width=positional_restraint_width, + soft_core_umax=soft_core_umax, + soft_core_u0=soft_core_u0, + soft_core_a=soft_core_a, + ) + # Store the number of minimisation steps. 
+ self.setSteps(steps) + + def getSteps(self): + """ + Return the number of minimisation steps. + + Returns + ------- + + steps : int + The number of minimisation steps. + """ + return self._steps + + def setSteps(self, steps): + """ + Set the number of minimisation steps. + + Parameters + ---------- + + steps : int + The number of minimisation steps. + """ + if isinstance(steps, int): + self._steps = steps + else: + raise TypeError("'steps' must be of type 'int'") + + +class ATMEquilibration(_ATM): + """Equilibration protocol for ATM simulations.""" + + def __init__( + self, + system=None, + data=None, + timestep=2 * _Units.Time.femtosecond, + runtime=0.2 * _Units.Time.nanosecond, + temperature_start=300 * _Units.Temperature.kelvin, + temperature_end=300 * _Units.Temperature.kelvin, + temperature=None, + pressure=1 * _Units.Pressure.atm, + thermostat_time_constant=1 * _Units.Time.picosecond, + report_interval=100, + restart_interval=100, + core_alignment=True, + com_distance_restraint=True, + com_k=25 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + com_restraint_width=5 * _Units.Length.angstrom, + restraint=None, + force_constant=10 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + positional_restraint_width=0.5 * _Units.Length.angstrom, + align_k_distance=2.5 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + align_k_theta=10 * _Units.Energy.kcal_per_mol, + align_k_psi=10 * _Units.Energy.kcal_per_mol, + soft_core_umax=1000 * _Units.Energy.kcal_per_mol, + soft_core_u0=500 * _Units.Energy.kcal_per_mol, + soft_core_a=0.0625, + use_atm_force=False, + direction=1, + lambda1=0.0, + lambda2=0.0, + alpha=0.0 * _Units.Energy.kcal_per_mol, + uh=0.0 * _Units.Energy.kcal_per_mol, + W0=0.0 * _Units.Energy.kcal_per_mol, + ): + """ + Create a new equilibration protocol. + + Parameters + ---------- + + system : :class:`System `` + A prepared ATM system. + + data : dict + The ATM data dictionary. + + timestep : str, :class:`Time ` + The integration timestep. 
+ + runtime : str, :class:`Time ` + The running time. + + temperature_start : str, :class:`Temperature ` + The starting temperature. + + temperature_end : str, :class:`Temperature ` + The ending temperature. + + temperature : str, :class:`Temperature ` + The equilibration temperature. This takes precedence of over the other temperatures, i.e. to run at fixed temperature. + + pressure : str, :class:`Pressure ` + The pressure. Pass pressure=None to use the NVT ensemble. + + thermostat_time_constant : str, :class:`Time ` + Time constant for thermostat coupling. + + report_interval : int + The frequency at which statistics are recorded. (In integration steps.) + + restart_interval : int + The frequency at which restart configurations and trajectory + + core_alignment : bool + Whether to use rigid core restraints to align the two ligands. + + align_k_distance : int, float, str, :class:`GeneralUnit ` + The force constant for the distance portion of the alignment restraint (kcal/(mol A^2). + + align_k_theta : int, float, str, :class:`Energy ` + The force constant for the angular portion of the alignment restaint (kcal/mol). + + align_k_psi : int, float, str, :class:`Energy ` + The force constant for the dihedral portion of the alignment restraint (kcal/mol). + + com_distance_restraint : bool + Whether to use a center of mass distance restraint. + This restraint applies to the protein/host and both ligands, and + is used to maintain the relative positions of all of them. + + com_k : int, float, str, :class:`GeneralUnit ` + The force constant for the center of mass distance restraint (kcal/mol/A^2). + + com_restraint_width : int, float, str, :class:`Length + The width (tolerance) of the center of mass distance restraint (A). + + restraint : str, [int] + The type of restraint to perform. This should be one of the + following options: + "backbone" + Protein backbone atoms. 
The matching is done by a name + template, so is unreliable on conversion between + molecular file formats. + "heavy" + All non-hydrogen atoms that aren't part of water + molecules or free ions. + "all" + All atoms that aren't part of water molecules or free + ions. + Alternatively, the user can pass a list of atom indices for + more fine-grained control. If None, then no restraints are used. + + force_constant : float, :class:`GeneralUnit ` + The force constant for the restraint potential (kcal/(mol A^2). + + positional_restraint_width : float, :class:`Length ` + The width of the flat-bottom potential used for coordinate restraint in Angstroms. + + pos_restrained_atoms : [int] + The atoms to be restrained. + + soft_core_umax : int, float, str, :class:`Energy ` + The Umax value for the ATM softcore potential (kcal/mol). + + soft_core_u0 : int, float, str, :class:`Energy ` + The uh value for the ATM softcore potential (kcal/mol). + + soft_core_a : int, float, str, :class:`Energy ` + The a value for the ATM softcore potential. + + + use_atm_force : bool + Whether to apply the ATM force within the equilibration protocol. + + direction : str + The direction of the equilibration. Ignored if use_atm_force is False. + + lambda1 : float + The lambda1 value for the ATM force. Ignored if use_atm_force is False. + + lambda2 : float + The lambda2 value for the ATM force. Ignored if use_atm_force is False. + + alpha : int, float, str, :class:`Energy ` + The alpha value for the ATM force. Ignored if use_atm_force is False. + Value in kcal/mol. + + uh : int, float, str, :class:`Energy ` + The uh value for the ATM force. Ignored if use_atm_force is False. + Value in kcal/mol. + + W0 : int, float, str, :class:`Energy ` + The W0 value for the ATM force. Ignored if use_atm_force is False. + Value in kcal/mol. 
+ """ + super().__init__( + system=system, + data=data, + core_alignment=core_alignment, + com_distance_restraint=com_distance_restraint, + com_k=com_k, + com_restraint_width=com_restraint_width, + restraint=restraint, + force_constant=force_constant, + positional_restraint_width=positional_restraint_width, + align_k_distance=align_k_distance, + align_k_theta=align_k_theta, + align_k_psi=align_k_psi, + soft_core_umax=soft_core_umax, + soft_core_u0=soft_core_u0, + soft_core_a=soft_core_a, + ) + # Store + self.setTimestep(timestep) + + self.setRuntime(runtime) + # Constant temperature equilibration. + if temperature is not None: + self.setStartTemperature(temperature) + self.setEndTemperature(temperature) + self._is_const_temp = True + + # Heating / cooling simulation. + else: + self._is_const_temp = False + + # Set the start temperature. + self.setStartTemperature(temperature_start) + + # Set the final temperature. + self.setEndTemperature(temperature_end) + + # Constant temperature simulation. + if self._temperature_start == self._temperature_end: + self._is_const_temp = True + + # Set the system pressure. + if pressure is not None: + self.setPressure(pressure) + else: + self._pressure = None + + self.setThermostatTimeConstant(thermostat_time_constant) + + self.setReportInterval(report_interval) + + self.setRestartInterval(restart_interval) + + self.setUseATMForce(use_atm_force) + + self.setDirection(direction) + + self.setLambda1(lambda1) + + self.setLambda2(lambda2) + + self.setAlpha(alpha) + + self.setUh(uh) + + self.setW0(W0) + + def getTimeStep(self): + """ + Return the time step. + + Returns + ------- + + time : :class:`Time ` + The integration time step. + """ + return self._timestep + + def setTimestep(self, timestep): + """ + Set the time step. + + Parameters + ---------- + + time : str, :class:`Time ` + The integration time step. 
+ """ + if isinstance(timestep, str): + try: + self._timestep = _Types.Time(timestep) + except: + raise ValueError("Unable to parse 'timestep' string.") from None + elif isinstance(timestep, _Types.Time): + self._timestep = timestep + else: + raise TypeError( + "'timestep' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getRunTime(self): + """ + Return the running time. + + Returns + ------- + + runtime : :class:`Time ` + The simulation run time. + """ + return self._runtime + + def setRuntime(self, runtime): + """ + Set the running time. + + Parameters + ---------- + + runtime : str, :class:`Time ` + The simulation run time. + """ + if isinstance(runtime, str): + try: + self._runtime = _Types.Time(runtime) + except: + raise ValueError("Unable to parse 'runtime' string.") from None + elif isinstance(runtime, _Types.Time): + self._runtime = runtime + else: + raise TypeError( + "'runtime' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getStartTemperature(self): + """ + Return the starting temperature. + + Returns + ------- + + temperature : :class:`Temperature ` + The starting temperature. + """ + return self._temperature_start + + def setStartTemperature(self, temperature): + """ + Set the starting temperature. + + Parameters + ---------- + + temperature : str, :class:`Temperature ` + The starting temperature. + """ + + if isinstance(temperature, str): + try: + temperature = _Types.Temperature(temperature) + except: + raise ValueError("Unable to parse 'temperature' string.") from None + elif not isinstance(temperature, _Types.Temperature): + raise TypeError( + "'temperature' must be of type 'str' or 'BioSimSpace.Types.Temperature'" + ) + + if _math.isclose(temperature.kelvin().value(), 0, rel_tol=1e-6): + temperature._value = 0.01 + self._temperature_start = temperature + + def getEndTemperature(self): + """ + Return the final temperature. + + Returns + ------- + + temperature : :class:`Temperature ` + The final temperature. 
+ """ + return self._temperature_end + + def setEndTemperature(self, temperature): + """ + Set the final temperature. + + Parameters + ---------- + + temperature : str, :class:`Temperature ` + The final temperature. + """ + if isinstance(temperature, str): + try: + temperature = _Types.Temperature(temperature) + except: + raise ValueError("Unable to parse 'temperature' string.") from None + elif not isinstance(temperature, _Types.Temperature): + raise TypeError( + "'temperature' must be of type 'str' or 'BioSimSpace.Types.Temperature'" + ) + + if _math.isclose(temperature.kelvin().value(), 0, rel_tol=1e-6): + temperature._value = 0.01 + self._temperature_end = temperature + + def getPressure(self): + """ + Return the pressure. + + Returns + ------- + + pressure : :class:`Pressure ` + The pressure. + """ + return self._pressure + + def setPressure(self, pressure): + """ + Set the pressure. + + Parameters + ---------- + + pressure : str, :class:`Pressure ` + The pressure. + """ + if isinstance(pressure, str): + try: + self._pressure = _Types.Pressure(pressure) + except: + raise ValueError("Unable to parse 'pressure' string.") from None + elif isinstance(pressure, _Types.Pressure): + self._pressure = pressure + else: + raise TypeError( + "'pressure' must be of type 'str' or 'BioSimSpace.Types.Pressure'" + ) + + def getThermostatTimeConstant(self): + """ + Return the time constant for the thermostat. + + Returns + ------- + + runtime : :class:`Time ` + The time constant for the thermostat. + """ + return self._thermostat_time_constant + + def setThermostatTimeConstant(self, thermostat_time_constant): + """ + Set the time constant for the thermostat. + + Parameters + ---------- + + thermostat_time_constant : str, :class:`Time ` + The time constant for the thermostat. 
+ """ + if isinstance(thermostat_time_constant, str): + try: + self._thermostat_time_constant = _Types.Time(thermostat_time_constant) + except: + raise ValueError( + "Unable to parse 'thermostat_time_constant' string." + ) from None + elif isinstance(thermostat_time_constant, _Types.Time): + self._thermostat_time_constant = thermostat_time_constant + else: + raise TypeError( + "'thermostat_time_constant' must be of type 'BioSimSpace.Types.Time'" + ) + + def getReportInterval(self): + """ + Return the interval between reporting statistics. (In integration steps.). + Returns + ------- + report_interval : int + The number of integration steps between reporting statistics. + """ + return self._report_interval + + def setReportInterval(self, report_interval): + """ + Set the interval at which statistics are reported. (In integration steps.). + + Parameters + ---------- + + report_interval : int + The number of integration steps between reporting statistics. + """ + if not type(report_interval) is int: + raise TypeError("'report_interval' must be of type 'int'") + + if report_interval <= 0: + _warnings.warn("'report_interval' must be positive. Using default (100).") + report_interval = 100 + + self._report_interval = report_interval + + def getRestartInterval(self): + """ + Return the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Returns + ------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. + """ + return self._restart_interval + + def setRestartInterval(self, restart_interval): + """ + Set the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Parameters + ---------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. 
+ """ + if not type(restart_interval) is int: + raise TypeError("'restart_interval' must be of type 'int'") + + if restart_interval <= 0: + _warnings.warn("'restart_interval' must be positive. Using default (500).") + restart_interval = 500 + + self._restart_interval = restart_interval + + def getUseATMForce(self): + """ + Return the use_atm_force flag. + + Returns + ------- + + use_atm_force : bool + Whether to apply the ATM force within the equilibration protocol. + """ + return self._use_atm_force + + def setUseATMForce(self, use_atm_force): + """ + Set the use_atm_force flag. + + Parameters + ---------- + + use_atm_force : bool + Whether to apply the ATM force within the equilibration protocol. + """ + if not isinstance(use_atm_force, bool): + raise TypeError("'use_atm_force' must be of type 'bool'") + self._use_atm_force = use_atm_force + + def getDirection(self): + """ + Return the direction of the equilibration. + + Returns + ------- + + direction : str + The direction of the equilibration. Ignored if use_atm_force is False. + """ + return self._direction + + def setDirection(self, direction): + """ + Set the direction of the equilibration. + + Parameters + ---------- + + direction : str + The direction of the equilibration. Ignored if use_atm_force is False. + """ + if int(direction) != 1 and int(direction) != -1: + raise TypeError("'direction' must have a value of 1 or -1") + self._direction = int(direction) + + def getLambda1(self): + """ + Return the lambda1 value for the ATM force. + + Returns + ------- + + lambda1 : float + The lambda1 value for the ATM force. Ignored if use_atm_force is False. + """ + return self._lambda1 + + def setLambda1(self, lambda1): + """ + Set the lambda1 value for the ATM force. + + Parameters + ---------- + + lambda1 : float + The lambda1 value for the ATM force. Ignored if use_atm_force is False. 
+ """ + if not isinstance(lambda1, (float, int)): + raise TypeError("'lambda1' must be of type 'float'") + if not 0 <= float(lambda1) <= 0.5: + raise ValueError("lambda1 must be between 0 and 0.5") + self._lambda1 = float(lambda1) + + def getLambda2(self): + """ + Return the lambda2 value for the ATM force. + + Returns + ------- + + lambda2 : float + The lambda2 value for the ATM force. Ignored if use_atm_force is False. + """ + return self._lambda2 + + def setLambda2(self, lambda2): + """ + Set the lambda2 value for the ATM force. + + Parameters + ---------- + + lambda2 : float + The lambda2 value for the ATM force. Ignored if use_atm_force is False. + """ + if not isinstance(lambda2, (float, int)): + raise TypeError("'lambda2' must be of type 'float'") + if not 0 <= float(lambda2) <= 0.5: + raise ValueError("lambda2 must be between 0 and 0.5") + self._lambda2 = float(lambda2) + + def getAlpha(self): + """ + Return the alpha value for the ATM force. + + Returns + ------- + + alpha : :class:`Energy ` + The alpha value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + return self._alpha + + def setAlpha(self, alpha): + """ + Set the alpha value for the ATM force. + + Parameters + ---------- + + alpha : int, float, str, :class:`Energy ` + The alpha value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(alpha) is int: + alpha = float(alpha) + + if isinstance(alpha, float): + # Use default units. + alpha *= _Units.Energy.kcal_per_mol + + else: + if isinstance(alpha, str): + try: + alpha = _Types._GeneralUnit(alpha) + except Exception: + raise ValueError("Unable to parse 'alpha' string.") from None + + elif not isinstance(alpha, _Types.Energy): + raise TypeError( + "'alpha' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. 
+ if alpha.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{alpha.unit()}'" + ) + self._alpha = alpha + + def getUh(self): + """ + Return the uh value for the ATM force. + + Returns + ------- + + uh : :class:`Energy ` + The uh value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + return self._uh + + def setUh(self, uh): + """ + Set the uh value for the ATM force. + + Parameters + ---------- + + uh : int, float, str, :class:`Energy ` + The uh value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(uh) is int: + uh = float(uh) + + if isinstance(uh, float): + # Use default units. + uh *= _Units.Energy.kcal_per_mol + + else: + if isinstance(uh, str): + try: + uh = _Types._GeneralUnit(uh) + except Exception: + raise ValueError("Unable to parse 'uh' string.") from None + + elif not isinstance(uh, _Types.Energy): + raise TypeError( + "'uh' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if uh.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{uh.unit()}'" + ) + self._uh = uh + + def getW0(self): + """ + Return the W0 value for the ATM force. + + Returns + ------- + + W0 : :class:`Energy ` + The W0 value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + return self._W0 + + def setW0(self, W0): + """ + Set the W0 value for the ATM force. + + Parameters + ---------- + + W0 :int, float, str, :class:`Energy ` + The W0 value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(W0) is int: + W0 = float(W0) + + if isinstance(W0, float): + # Use default units. 
+ W0 *= _Units.Energy.kcal_per_mol + + else: + if isinstance(W0, str): + try: + W0 = _Types._GeneralUnit(W0) + except Exception: + raise ValueError("Unable to parse 'W0' string.") from None + + elif not isinstance(W0, _Types.Energy): + raise TypeError( + "'W0' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if W0.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{W0.unit()}'" + ) + self._W0 = W0 + + def isConstantTemp(self): + """ + Return whether the protocol has a constant temperature. + + Returns + ------- + + is_const_temp : bool + Whether the temperature is fixed. + """ + return self._temperature_start == self._temperature_end + + @classmethod + def restraints(cls): + """ + Return a list of the supported restraint keywords. + + Returns + ------- + + restraints : [str] + A list of the supported restraint keywords. 
+ """ + return cls._restraints.copy() + + +class ATMAnnealing(_ATM): + """Annealing protocol for ATM simulations.""" + + def __init__( + self, + system=None, + data=None, + timestep=2 * _Units.Time.femtosecond, + runtime=0.2 * _Units.Time.nanosecond, + temperature=300 * _Units.Temperature.kelvin, + pressure=1 * _Units.Pressure.atm, + thermostat_time_constant=1 * _Units.Time.picosecond, + report_interval=100, + restart_interval=100, + core_alignment=True, + com_distance_restraint=True, + com_k=25 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + com_restraint_width=5 * _Units.Length.angstrom, + restraint=None, + force_constant=10 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + positional_restraint_width=0.5 * _Units.Length.angstrom, + align_k_distance=2.5 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + align_k_theta=10 * _Units.Energy.kcal_per_mol, + align_k_psi=10 * _Units.Energy.kcal_per_mol, + soft_core_umax=1000 * _Units.Energy.kcal_per_mol, + soft_core_u0=500 * _Units.Energy.kcal_per_mol, + soft_core_a=0.0625, + direction=1, + lambda1=0.0, + lambda2=0.0, + alpha=0.0, + uh=0.0, + W0=0.0, + anneal_values="default", + anneal_numcycles=100, + ): + """ + Create a new annealing protocol. + + Parameters + ---------- + system : :class:`System ` + A prepared ATM system. + + data : dict + The ATM data dictionary. + + timestep : str, :class:`Time ` + The integration timestep. + + runtime : str, :class:`Time ` + The running time. + + temperature : str, :class:`Temperature ` + The temperature. + + pressure : str, :class:`Pressure ` + The pressure. Pass pressure=None to use the NVT ensemble. + + thermostat_time_constant : str, :class:`Time ` + Time constant for thermostat coupling. + + report_interval : int + The frequency at which statistics are recorded. (In integration steps.) 
+ + restart_interval : int + The frequency at which restart configurations and trajectory + frames are saved. (In integration steps.) + + core_alignment : bool + Whether to use rigid core restraints to align the two ligands. + + align_k_distance : int, float, str, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the distance portion of the alignment restraint (kcal/(mol A^2)). + + align_k_theta : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The force constant for the angular portion of the alignment restraint (kcal/mol). + + align_k_psi : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The force constant for the dihedral portion of the alignment restraint (kcal/mol). + + com_distance_restraint : bool + Whether to use a center of mass distance restraint. + This restraint applies to the protein/host and both ligands, and + is used to maintain the relative positions of all of them. + + com_k : int, float, str, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the center of mass distance restraint (kcal/mol/A^2). + + com_restraint_width : int, float, str, :class:`Length <BioSimSpace.Types.Length>` + The width (tolerance) of the center of mass distance restraint (A). + + restraint : str, [int] + The type of restraint to perform. This should be one of the + following options: + "backbone" + Protein backbone atoms. The matching is done by a name + template, so is unreliable on conversion between + molecular file formats. + "heavy" + All non-hydrogen atoms that aren't part of water + molecules or free ions. + "all" + All atoms that aren't part of water molecules or free + ions. + Alternatively, the user can pass a list of atom indices for + more fine-grained control. If None, then no restraints are used. + + force_constant : float, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the restraint potential. If a 'float' is + passed, then default units of 'kcal_per_mol / angstrom**2' will + be used. + + positional_restraint_width : float, :class:`Length <BioSimSpace.Types.Length>` + The width of the flat-bottom potential used for coordinate restraint in Angstroms. 
+ + pos_restrained_atoms : [int] + The atoms to be restrained. + + soft_core_umax : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The Umax value for the ATM softcore potential (kcal/mol). + + soft_core_u0 : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The u0 value for the ATM softcore potential (kcal/mol). + + soft_core_a : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The a value for the ATM softcore potential. + + direction : int + The direction of the annealing. Must have a value of 1 or -1. + + lambda1 : float + The lambda1 value for the ATM force. + Superseded by any values defined in anneal_values. + + lambda2 : float + The lambda2 value for the ATM force. + Superseded by any values defined in anneal_values. + + alpha : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The alpha value for the ATM force. + Value in kcal/mol. + Superseded by any values defined in anneal_values. + + uh : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The uh value for the ATM force. + Value in kcal/mol. + Superseded by any values defined in anneal_values. + + W0 : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The W0 value for the ATM force. + Value in kcal/mol. + Superseded by any values defined in anneal_values. + + anneal_values : dict, None, "default" + If None, then no annealing will be performed. + If "default", then lambda values will be annealed from 0 to 0.5. + If more complex annealing is required, then + a dictionary with some or all of the following keys should be given: + "lambda1_start" : float + The starting value for lambda1. + "lambda1_end" : float + The ending value for lambda1. + "lambda2_start" : float + The starting value for lambda2. + "lambda2_end" : float + The ending value for lambda2. + "alpha_start" : float + The starting value for alpha. + "alpha_end" : float + The ending value for alpha. + "uh_start" : float + The starting value for uh. + "uh_end" : float + The ending value for uh. + "W0_start" : float + The starting value for W0. + "W0_end" : float + The ending value for W0. + Any unspecified values will use their default lambda=0 value. 
+ + anneal_numcycles : int + The number of annealing cycles to perform, defines the rate at which values are incremented. Default 100. + """ + super().__init__( + system=system, + data=data, + core_alignment=core_alignment, + com_distance_restraint=com_distance_restraint, + com_k=com_k, + com_restraint_width=com_restraint_width, + restraint=restraint, + force_constant=force_constant, + positional_restraint_width=positional_restraint_width, + align_k_distance=align_k_distance, + align_k_theta=align_k_theta, + align_k_psi=align_k_psi, + soft_core_umax=soft_core_umax, + soft_core_u0=soft_core_u0, + soft_core_a=soft_core_a, + ) + + self.setTimestep(timestep) + + self.setRuntime(runtime) + + self.setTemperature(temperature) + + # Set the system pressure. + if pressure is not None: + self.setPressure(pressure) + else: + self._pressure = None + + self.setThermostatTimeConstant(thermostat_time_constant) + + self.setReportInterval(report_interval) + + self.setRestartInterval(restart_interval) + + self.setDirection(direction) + + self.setLambda1(lambda1) + + self.setLambda2(lambda2) + + self.setAlpha(alpha) + + self.setUh(uh) + + self.setW0(W0) + + # Store the anneal values. + self.setAnnealValues(anneal_values) + + # Set the number of annealing cycles. + self.setAnnealNumCycles(anneal_numcycles) + + def getTimeStep(self): + """ + Return the time step. + + Returns + ------- + + time : :class:`Time ` + The integration time step. + """ + return self._timestep + + def setTimestep(self, timestep): + """ + Set the time step. + + Parameters + ---------- + + time : str, :class:`Time ` + The integration time step. 
+ """ + if isinstance(timestep, str): + try: + self._timestep = _Types.Time(timestep) + except: + raise ValueError("Unable to parse 'timestep' string.") from None + elif isinstance(timestep, _Types.Time): + self._timestep = timestep + else: + raise TypeError( + "'timestep' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getRunTime(self): + """ + Return the running time. + + Returns + ------- + + runtime : :class:`Time ` + The simulation run time. + """ + return self._runtime + + def setRuntime(self, runtime): + """ + Set the running time. + + Parameters + ---------- + + runtime : str, :class:`Time ` + The simulation run time. + """ + if isinstance(runtime, str): + try: + self._runtime = _Types.Time(runtime) + except: + raise ValueError("Unable to parse 'runtime' string.") from None + elif isinstance(runtime, _Types.Time): + self._runtime = runtime + else: + raise TypeError( + "'runtime' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getTemperature(self): + """ + Return temperature. + + Returns + ------- + + temperature : :class:`Temperature ` + The simulation temperature. + """ + return self._temperature + + def setTemperature(self, temperature): + """ + Set the temperature. + + Parameters + ---------- + + temperature : str, :class:`Temperature ` + The simulation temperature. + """ + if isinstance(temperature, str): + try: + self._temperature = _Types.Temperature(temperature) + except: + raise ValueError("Unable to parse 'temperature' string.") from None + elif isinstance(temperature, _Types.Temperature): + self._temperature = temperature + else: + raise TypeError( + "'temperature' must be of type 'str' or 'BioSimSpace.Types.Temperature'" + ) + + def getPressure(self): + """ + Return the pressure. + + Returns + ------- + + pressure : :class:`Pressure ` + The pressure. + """ + return self._pressure + + def setPressure(self, pressure): + """ + Set the pressure. 
+ + Parameters + ---------- + + pressure : str, :class:`Pressure ` + The pressure. + """ + if isinstance(pressure, str): + try: + self._pressure = _Types.Pressure(pressure) + except: + raise ValueError("Unable to parse 'pressure' string.") from None + elif isinstance(pressure, _Types.Pressure): + self._pressure = pressure + else: + raise TypeError( + "'pressure' must be of type 'str' or 'BioSimSpace.Types.Pressure'" + ) + + def getThermostatTimeConstant(self): + """ + Return the time constant for the thermostat. + + Returns + ------- + + runtime : :class:`Time ` + The time constant for the thermostat. + """ + return self._thermostat_time_constant + + def setThermostatTimeConstant(self, thermostat_time_constant): + """ + Set the time constant for the thermostat. + + Parameters + ---------- + + thermostat_time_constant : str, :class:`Time ` + The time constant for the thermostat. + """ + if isinstance(thermostat_time_constant, str): + try: + self._thermostat_time_constant = _Types.Time(thermostat_time_constant) + except: + raise ValueError( + "Unable to parse 'thermostat_time_constant' string." + ) from None + elif isinstance(thermostat_time_constant, _Types.Time): + self._thermostat_time_constant = thermostat_time_constant + else: + raise TypeError( + "'thermostat_time_constant' must be of type 'BioSimSpace.Types.Time'" + ) + + def getReportInterval(self): + """ + Return the interval between reporting statistics. (In integration steps.). + Returns + ------- + report_interval : int + The number of integration steps between reporting statistics. + """ + return self._report_interval + + def setReportInterval(self, report_interval): + """ + Set the interval at which statistics are reported. (In integration steps.). + + Parameters + ---------- + + report_interval : int + The number of integration steps between reporting statistics. 
+ """ + if not type(report_interval) is int: + raise TypeError("'report_interval' must be of type 'int'") + + if report_interval <= 0: + _warnings.warn("'report_interval' must be positive. Using default (100).") + report_interval = 100 + + self._report_interval = report_interval + + def getRestartInterval(self): + """ + Return the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Returns + ------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. + """ + return self._restart_interval + + def setRestartInterval(self, restart_interval): + """ + Set the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Parameters + ---------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. + """ + if not type(restart_interval) is int: + raise TypeError("'restart_interval' must be of type 'int'") + + if restart_interval <= 0: + _warnings.warn("'restart_interval' must be positive. Using default (500).") + restart_interval = 500 + + self._restart_interval = restart_interval + + def getDirection(self): + """ + Return the direction of the equilibration. + + Returns + ------- + + direction : str + The direction of the equilibration. Ignored if use_atm_force is False. + """ + return self._direction + + def setDirection(self, direction): + """ + Set the direction of the equilibration. + + Parameters + ---------- + + direction : str + The direction of the equilibration. Ignored if use_atm_force is False. + """ + if int(direction) != 1 and int(direction) != -1: + raise TypeError("'direction' must have a value of 1 or -1") + self._direction = int(direction) + + def getLambda1(self): + """ + Return the lambda1 value for the ATM force. + + Returns + ------- + + lambda1 : float + The lambda1 value for the ATM force. 
Ignored if use_atm_force is False. + """ + return self._lambda1 + + def setLambda1(self, lambda1): + """ + Set the lambda1 value for the ATM force. + + Parameters + ---------- + + lambda1 : float + The lambda1 value for the ATM force. Ignored if use_atm_force is False. + """ + if not isinstance(lambda1, (float, int)): + raise TypeError("'lambda1' must be of type 'float'") + if not 0 <= float(lambda1) <= 0.5: + raise ValueError("lambda1 must be between 0 and 0.5") + self._lambda1 = float(lambda1) + + def getLambda2(self): + """ + Return the lambda2 value for the ATM force. + + Returns + ------- + + lambda2 : float + The lambda2 value for the ATM force. Ignored if use_atm_force is False. + """ + return self._lambda2 + + def setLambda2(self, lambda2): + """ + Set the lambda2 value for the ATM force. + + Parameters + ---------- + + lambda2 : float + The lambda2 value for the ATM force. Ignored if use_atm_force is False. + """ + if not isinstance(lambda2, (float, int)): + raise TypeError("'lambda2' must be of type 'float'") + if not 0 <= float(lambda2) <= 0.5: + raise ValueError("lambda2 must be between 0 and 0.5") + self._lambda2 = float(lambda2) + + def getAlpha(self): + """ + Return the alpha value for the ATM force. + + Returns + ------- + + alpha : :class:`Energy ` + The alpha value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + return self._alpha + + def setAlpha(self, alpha): + """ + Set the alpha value for the ATM force. + + Parameters + ---------- + + alpha : int, float, str, :class:`Energy ` + The alpha value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(alpha) is int: + alpha = float(alpha) + + if isinstance(alpha, float): + # Use default units. 
+ alpha *= _Units.Energy.kcal_per_mol + + else: + if isinstance(alpha, str): + try: + alpha = _Types._GeneralUnit(alpha) + except Exception: + raise ValueError("Unable to parse 'alpha' string.") from None + + elif not isinstance(alpha, _Types.Energy): + raise TypeError( + "'alpha' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if alpha.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{alpha.unit()}'" + ) + self._alpha = alpha + + def getUh(self): + """ + Return the uh value for the ATM force. + + Returns + ------- + + uh : :class:`Energy ` + The uh value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + return self._uh + + def setUh(self, uh): + """ + Set the uh value for the ATM force. + + Parameters + ---------- + + uh : int, float, str, :class:`Energy ` + The uh value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(uh) is int: + uh = float(uh) + + if isinstance(uh, float): + # Use default units. + uh *= _Units.Energy.kcal_per_mol + + else: + if isinstance(uh, str): + try: + uh = _Types._GeneralUnit(uh) + except Exception: + raise ValueError("Unable to parse 'uh' string.") from None + + elif not isinstance(uh, _Types.Energy): + raise TypeError( + "'uh' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if uh.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{uh.unit()}'" + ) + self._uh = uh + + def getW0(self): + """ + Return the W0 value for the ATM force. + + Returns + ------- + + W0 : :class:`Energy ` + The W0 value for the ATM force in kcal/mol. Ignored if use_atm_force is False. 
+ """ + return self._W0 + + def setW0(self, W0): + """ + Set the W0 value for the ATM force. + + Parameters + ---------- + + W0 :int, float, str, :class:`Energy ` + The W0 value for the ATM force in kcal/mol. Ignored if use_atm_force is False. + """ + # Convert int to float. + if type(W0) is int: + W0 = float(W0) + + if isinstance(W0, float): + # Use default units. + W0 *= _Units.Energy.kcal_per_mol + + else: + if isinstance(W0, str): + try: + W0 = _Types._GeneralUnit(W0) + except Exception: + raise ValueError("Unable to parse 'W0' string.") from None + + elif not isinstance(W0, _Types.Energy): + raise TypeError( + "'W0' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if W0.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'align_k_theta' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{W0.unit()}'" + ) + self._W0 = W0 + + def getAnnealValues(self): + """ + Return the anneal protocol. + + Returns + ------- + + anneal_protocol : dict + The anneal protocol. + """ + return self._anneal_values + + def setAnnealValues(self, anneal_values): + """ + Set the anneal protocol. + + Parameters + ---------- + + anneal_values : dict + The anneal values. 
+ """ + + def capitalise_keys(input_dict): + # The first letter of each key needs to be captilised + # so that it can be properly passed to openMM later + capitalized_dict = {} + for key, value in input_dict.items(): + capitalized_key = key.capitalize() + capitalized_dict[capitalized_key] = value + return capitalized_dict + + if anneal_values == "default": + self._anneal_values = capitalise_keys( + { + "lambda1_start": 0, + "lambda1_end": 0.5, + "lambda2_start": 0, + "lambda2_end": 0.5, + } + ) + elif isinstance(anneal_values, dict): + # check that the given keys are valid + keys = [ + "lambda1_start", + "lambda1_end", + "lambda2_start", + "lambda2_end", + "alpha_start", + "alpha_end", + "uh_start", + "uh_end", + "W0_start", + "W0_end", + ] + for key in anneal_values: + if key not in keys: + raise ValueError( + f"The anneal values can only contain the following keys: 'lambda1_start', 'lambda1_end', 'lambda2_start', 'lambda2_end', 'alpha_start', 'alpha_end', 'uh_start', 'uh_end', 'W0_start', 'W0_end', 'runtime'. 
The following keys are invalid: {key}" + ) + if key == "lambda1_start" or key == "lambda1_end": + if not 0 <= float(anneal_values[key]) <= 0.5: + raise ValueError("lambda1 must be between 0 and 0.5") + if key == "lambda2_start" or key == "lambda2_end": + if not 0 <= float(anneal_values[key]) <= 0.5: + raise ValueError("lambda2 must be between 0 and 0.5") + # check that none of the other keys are negative + if ( + key != "lambda1_start" + and key != "lambda1_end" + and key != "lambda2_start" + and key != "lambda2_end" + ): + if float(anneal_values[key]) < 0: + raise ValueError(f"{key} must be greater than or equal to 0") + # also check that they are floats + if not isinstance(anneal_values[key], (float, int)): + raise TypeError(f"{key} must be of type 'float'") + self._anneal_values = capitalise_keys(anneal_values) + elif anneal_values is None: + self._anneal_values = None + + else: + raise TypeError( + "'anneal_values' must be of type 'dict', 'None', or 'default'" + ) + + def getAnnealNumCycles(self): + """ + Return the number of annealing cycles. + + Returns + ------- + + anneal_numcycles : int + The number of annealing cycles. + """ + return self._anneal_numcycles + + def setAnnealNumCycles(self, anneal_numcycles): + """ + Set the number of annealing cycles. + + Parameters + ---------- + + anneal_numcycles : int + The number of annealing cycles. + """ + if isinstance(anneal_numcycles, int): + self._anneal_numcycles = anneal_numcycles + else: + raise TypeError("'anneal_numcycles' must be of type 'int'") + + def _set_current_index(self, index): + """ + The current index of the window. + In annealing protocols this should not be touched by the user. + + Parameters + ---------- + index : int + The index of the current lambda window. 
+ """ + if index < 0: + raise ValueError("index must be positive") + if not isinstance(index, int): + raise TypeError("index must be an integer") + self._current_index = index + + def _get_window_index(self): + """ + A function to get the index of the current lambda window. + + Returns + ------- + index : int + The index of the current lambda window. + """ + try: + return self._current_index + except: + return None + + +class ATMProduction(_ATM): + """Production protocol for ATM simulations.""" + + def __init__( + self, + system=None, + data=None, + timestep=2 * _Units.Time.femtosecond, + runtime=1.0 * _Units.Time.nanosecond, + temperature=300 * _Units.Temperature.kelvin, + pressure=1 * _Units.Pressure.atm, + thermostat_time_constant=1 * _Units.Time.picosecond, + report_interval=100, + restart_interval=100, + restart=False, + core_alignment=True, + com_distance_restraint=True, + com_k=25 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + com_restraint_width=5 * _Units.Length.angstrom, + restraint=None, + force_constant=10 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + positional_restraint_width=0.5 * _Units.Length.angstrom, + num_lambda=22, + direction=None, + lambda1=None, + lambda2=None, + alpha=None, + uh=None, + W0=None, + align_k_distance=2.5 * _Units.Energy.kcal_per_mol / _Units.Area.angstrom2, + align_k_theta=10 * _Units.Energy.kcal_per_mol, + align_k_psi=10 * _Units.Energy.kcal_per_mol, + soft_core_umax=100 * _Units.Energy.kcal_per_mol, + soft_core_u0=50 * _Units.Energy.kcal_per_mol, + soft_core_a=0.0625, + analysis_method="UWHAM", + ): + """ + Create a new production protocol. + + Parameters + ---------- + system : :class:`System ` + A prepared ATM system. + + data : dict + The ATM data dictionary. + + timestep : str, :class:`Time ` + The integration timestep. + + runtime : str, :class:`Time ` + The running time. + + temperature : str, :class:`Temperature ` + The temperature. + + pressure : str, :class:`Pressure ` + The pressure. 
Pass pressure=None to use the NVT ensemble. + + thermostat_time_constant : str, :class:`Time <BioSimSpace.Types.Time>` + Time constant for thermostat coupling. + + report_interval : int + The frequency at which statistics are recorded. (In integration steps.) + + restart_interval : int + The frequency at which restart configurations and trajectory + frames are saved. (In integration steps.) + + core_alignment : bool + Whether to use rigid core restraints to align the two ligands. + + align_k_distance : int, float, str, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the distance portion of the alignment restraint (kcal/(mol A^2)). + + align_k_theta : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The force constant for the angular portion of the alignment restraint (kcal/mol). + + align_k_psi : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The force constant for the dihedral portion of the alignment restraint (kcal/mol). + + com_distance_restraint : bool + Whether to use a center of mass distance restraint. + This restraint applies to the protein/host and both ligands, and + is used to maintain the relative positions of all of them. + + com_k : int, float, str, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the center of mass distance restraint (kcal/mol/A^2). + + com_restraint_width : int, float, str, :class:`Length <BioSimSpace.Types.Length>` + The width (tolerance) of the center of mass distance restraint (A). + + restraint : str, [int] + The type of restraint to perform. This should be one of the + following options: + "backbone" + Protein backbone atoms. The matching is done by a name + template, so is unreliable on conversion between + molecular file formats. + "heavy" + All non-hydrogen atoms that aren't part of water + molecules or free ions. + "all" + All atoms that aren't part of water molecules or free + ions. + Alternatively, the user can pass a list of atom indices for + more fine-grained control. If None, then no restraints are used. + + force_constant : float, :class:`GeneralUnit <BioSimSpace.Types._GeneralUnit>` + The force constant for the restraint potential. 
If a 'float' is + passed, then default units of 'kcal_per_mol / angstrom**2' will + be used. + + positional_restraint_width : :class:`Length <BioSimSpace.Types.Length>`, float + The width of the flat-bottom potential used for coordinate restraint in Angstroms. + + pos_restrained_atoms : [int] + The atoms to be restrained. + + soft_core_umax : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The Umax value for the ATM softcore potential (kcal/mol). + + soft_core_u0 : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The u0 value for the ATM softcore potential (kcal/mol). + + soft_core_a : int, float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The a value for the ATM softcore potential. + + restart : bool + Whether this is a continuation of a previous simulation. + + num_lambda : int + The number of lambda values. This will be used to set the window-dependent + ATM parameters, unless they are explicitly set by the user. + + lambdas : [float] + The lambda values. + + direction : [int] + The direction values. Must be either 1 (forwards) or -1 (backwards). + + lambda1 : [float] + The lambda1 values. + + lambda2 : [float] + The lambda2 values. + + alpha : [int], float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The alpha values. + + uh : [int], float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The uh values. + + W0 : [int], float, str, :class:`Energy <BioSimSpace.Types.Energy>` + The W0 values. + + analysis_method : str + The method to use for analysis. Options are "UWHAM", "MBAR" or "both". + This affects the output files and the analysis that is performed. + Use of "UWHAM" is strongly recommended; "MBAR" analysis is still experimental. 
+ """ + super().__init__( + system=system, + data=data, + core_alignment=core_alignment, + com_distance_restraint=com_distance_restraint, + com_k=com_k, + com_restraint_width=com_restraint_width, + restraint=restraint, + force_constant=force_constant, + positional_restraint_width=positional_restraint_width, + align_k_distance=align_k_distance, + align_k_theta=align_k_theta, + align_k_psi=align_k_psi, + soft_core_umax=soft_core_umax, + soft_core_u0=soft_core_u0, + soft_core_a=soft_core_a, + ) + + self.setTimestep(timestep) + + self.setRuntime(runtime) + + self.setTemperature(temperature) + + # Set the system pressure. + if pressure is not None: + self.setPressure(pressure) + else: + self._pressure = None + + self.setThermostatTimeConstant(thermostat_time_constant) + + self.setReportInterval(report_interval) + + self.setRestartInterval(restart_interval) + + # Set the restart flag. + self.setRestart(restart) + # Set the number of lambda values. + # If other window-dependent parameters are not set, then set them to + # sensible defaults. + self.setNumLambda(num_lambda) + + # Store the direction values. + self.setDirection(direction) + + # Store the lambda1 values. + self.setLambda1(lambda1) + + # Store the lambda2 values. + self.setLambda2(lambda2) + + # Store the alpha values. + self.setAlpha(alpha) + + # Store the uh values. + self.setUh(uh) + + # Store the W0 values. + self.setW0(W0) + + self._set_lambda_values() + + self.setAnalysisMethod(analysis_method) + + def getTimeStep(self): + """ + Return the time step. + + Returns + ------- + + time : :class:`Time ` + The integration time step. + """ + return self._timestep + + def setTimestep(self, timestep): + """ + Set the time step. + + Parameters + ---------- + + time : str, :class:`Time ` + The integration time step. 
+ """ + if isinstance(timestep, str): + try: + self._timestep = _Types.Time(timestep) + except: + raise ValueError("Unable to parse 'timestep' string.") from None + elif isinstance(timestep, _Types.Time): + self._timestep = timestep + else: + raise TypeError( + "'timestep' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getRunTime(self): + """ + Return the running time. Set the same as other OpenMM protocols - really should be Runtime not RunTime. + + Returns + ------- + + runtime : :class:`Time ` + The simulation run time. + """ + return self._runtime + + def setRuntime(self, runtime): + """ + Set the running time. + + Parameters + ---------- + + runtime : str, :class:`Time ` + The simulation run time. + """ + if isinstance(runtime, str): + try: + self._runtime = _Types.Time(runtime) + except: + raise ValueError("Unable to parse 'runtime' string.") from None + elif isinstance(runtime, _Types.Time): + self._runtime = runtime + else: + raise TypeError( + "'runtime' must be of type 'str' or 'BioSimSpace.Types.Time'" + ) + + def getTemperature(self): + """ + Return temperature. + + Returns + ------- + + temperature : :class:`Temperature ` + The simulation temperature. + """ + return self._temperature + + def setTemperature(self, temperature): + """ + Set the temperature. + + Parameters + ---------- + + temperature : str, :class:`Temperature ` + The simulation temperature. + """ + if isinstance(temperature, str): + try: + self._temperature = _Types.Temperature(temperature) + except: + raise ValueError("Unable to parse 'temperature' string.") from None + elif isinstance(temperature, _Types.Temperature): + self._temperature = temperature + else: + raise TypeError( + "'temperature' must be of type 'str' or 'BioSimSpace.Types.Temperature'" + ) + + def getPressure(self): + """ + Return the pressure. + + Returns + ------- + + pressure : :class:`Pressure ` + The pressure. 
+ """ + return self._pressure + + def setPressure(self, pressure): + """ + Set the pressure. + + Parameters + ---------- + + pressure : str, :class:`Pressure ` + The pressure. + """ + if isinstance(pressure, str): + try: + self._pressure = _Types.Pressure(pressure) + except: + raise ValueError("Unable to parse 'pressure' string.") from None + elif isinstance(pressure, _Types.Pressure): + self._pressure = pressure + else: + raise TypeError( + "'pressure' must be of type 'str' or 'BioSimSpace.Types.Pressure'" + ) + + def getThermostatTimeConstant(self): + """ + Return the time constant for the thermostat. + + Returns + ------- + + runtime : :class:`Time ` + The time constant for the thermostat. + """ + return self._thermostat_time_constant + + def setThermostatTimeConstant(self, thermostat_time_constant): + """ + Set the time constant for the thermostat. + + Parameters + ---------- + + thermostat_time_constant : str, :class:`Time ` + The time constant for the thermostat. + """ + if isinstance(thermostat_time_constant, str): + try: + self._thermostat_time_constant = _Types.Time(thermostat_time_constant) + except: + raise ValueError( + "Unable to parse 'thermostat_time_constant' string." + ) from None + elif isinstance(thermostat_time_constant, _Types.Time): + self._thermostat_time_constant = thermostat_time_constant + else: + raise TypeError( + "'thermostat_time_constant' must be of type 'BioSimSpace.Types.Time'" + ) + + def getReportInterval(self): + """ + Return the interval between reporting statistics. (In integration steps.). + Returns + ------- + report_interval : int + The number of integration steps between reporting statistics. + """ + return self._report_interval + + def setReportInterval(self, report_interval): + """ + Set the interval at which statistics are reported. (In integration steps.). + + Parameters + ---------- + + report_interval : int + The number of integration steps between reporting statistics. 
+ """ + if not type(report_interval) is int: + raise TypeError("'report_interval' must be of type 'int'") + + if report_interval <= 0: + _warnings.warn("'report_interval' must be positive. Using default (100).") + report_interval = 100 + + self._report_interval = report_interval + + def getRestartInterval(self): + """ + Return the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Returns + ------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. + """ + return self._restart_interval + + def setRestartInterval(self, restart_interval): + """ + Set the interval between saving restart confiugrations, and/or + trajectory frames. (In integration steps.). + + Parameters + ---------- + + restart_interval : int + The number of integration steps between saving restart + configurations and/or trajectory frames. + """ + if not type(restart_interval) is int: + raise TypeError("'restart_interval' must be of type 'int'") + + if restart_interval <= 0: + _warnings.warn("'restart_interval' must be positive. Using default (500).") + restart_interval = 500 + + self._restart_interval = restart_interval + + def isRestart(self): + """ + Return whether this restart simulation. + + Returns + ------- + + is_restart : bool + Whether this is a restart simulation. + """ + return self._restart + + def setRestart(self, restart): + """ + Set the restart flag. + + Parameters + ---------- + + restart : bool + Whether this is a restart simulation. + """ + if isinstance(restart, bool): + self._restart = restart + else: + _warnings.warn("Non-boolean restart flag. Defaulting to False!") + self._restart = False + + def getNumLambda(self): + """ + Return the number of lambda values. + + Returns + ------- + + num_lambda : int + The number of lambda values. + """ + return self._num_lambda + + def setNumLambda(self, num_lambda): + """ + Set the number of lambda values. 
+ + Parameters + ---------- + + num_lambda : int + The number of lambda values. + """ + if isinstance(num_lambda, int) and num_lambda > 0: + if num_lambda % 2 != 0: + _warnings.warn( + "Warning: The ATM protocol is optimised for an even number of lambda values. " + "Unknown behaviour may occur if using an odd number of lambda values." + ) + self._num_lambda = num_lambda + self._set_lambda_values() + else: + raise TypeError("'num_lambda' must be of type 'int'") + + def getDirection(self): + """ + Return the direction values. + + Returns + ------- + + lambdas : [float] + The directions. + """ + return self._directions + + def setDirection(self, directions): + """ + Set the direction values. + + Parameters + ---------- + + directions : [int] + The directions. + """ + if isinstance(directions, list): + if len(directions) != self._num_lambda: + raise ValueError( + "'directions' must have the same length as 'num_lambda'" + ) + if all(item == 1 or item == -1 for item in directions): + self._directions = directions + else: + raise ValueError("all entries in 'directions' must be either 1 or -1") + elif directions is None: + self._directions = [1] * _math.floor(self._num_lambda / 2) + [ + -1 + ] * _math.ceil(self._num_lambda / 2) + else: + raise TypeError("'directions' must be of type 'list' or 'None'") + + def getLambda1(self): + """ + Return the lambda1 values. + + Returns + ------- + + lambda1 : [float] + The lambda1 values. + """ + return self._lambda1 + + def setLambda1(self, lambda1): + """ + Set the lambda1 values. + + Parameters + ---------- + + lambda1 : [float] + The lambda1 values. 
+ """ + if isinstance(lambda1, list): + if len(lambda1) != self._num_lambda: + raise ValueError("'lambda1' must have the same length as 'num_lambda'") + if all(isinstance(item, float) for item in lambda1) and all( + item <= 0.5 for item in lambda1 + ): + self._lambda1 = lambda1 + else: + raise ValueError( + "all entries in 'lambda1' must be floats with a value less than or equal to 0.5" + ) + elif lambda1 is None: + # use numpy to create a [float]s + self._lambda1 = _np.concatenate( + [ + _np.linspace(0, 0.5, _math.floor(self._num_lambda / 2)), + _np.linspace(0.5, 0, _math.ceil(self._num_lambda / 2)), + ] + ).tolist() + # Round the floats to 5 decimal places + self._lambda1 = [round(num, 5) for num in self._lambda1] + else: + raise TypeError("'lambda1' must be of type 'list'") + + def getLambda2(self): + """ + Return the lambda2 values. + + Returns + ------- + + lambda2 : [float] + The lambda2 values. + """ + return self._lambda2 + + def setLambda2(self, lambda2): + """ + Set the lambda2 values. + + Parameters + ---------- + + lambda2 : [float] + The lambda2 values. 
+ """ + if isinstance(lambda2, list): + if len(lambda2) != self._num_lambda: + raise ValueError("'lambda2' must have the same length as 'num_lambda'") + if all(isinstance(item, float) for item in lambda2) and all( + item <= 0.5 for item in lambda2 + ): + if len(lambda2) != len(self._lambda1): + raise ValueError( + "'lambda2' and 'lambda1' must have the same length" + ) + self._lambda2 = lambda2 + else: + raise ValueError("all entries in 'lambda2' must be floats") + elif lambda2 is None: + # use numpy to create a [float]s + self._lambda2 = _np.concatenate( + [ + _np.linspace(0, 0.5, _math.floor(self._num_lambda / 2)), + _np.linspace(0.5, 0, _math.ceil(self._num_lambda / 2)), + ] + ).tolist() + # Round the floats to 5 decimal places + self._lambda2 = [round(num, 5) for num in self._lambda2] + else: + raise TypeError("'lambda2' must be of type 'list'") + + def getAlpha(self): + """ + Return the alpha values. + + Returns + ------- + + alpha : [:class:`Energy ] + The alpha values in kcal/mol. + """ + return self._alpha + + def setAlpha(self, alpha): + """ + Set the alpha values. + + Parameters + ---------- + + alpha : [`Energy ] or [int], [float], [str] + The alpha values in kcal/mol. + """ + if isinstance(alpha, list): + if len(alpha) != self._num_lambda: + raise ValueError("'alpha' must have the same length as 'num_lambda'") + alpha_fin = [] + for a in alpha: + # Convert int to float. + if type(a) is int: + a = float(a) + a *= _Units.Energy.kcal_per_mol + + elif isinstance(a, float): + # Use default units. + a *= _Units.Energy.kcal_per_mol + + else: + if isinstance(a, str): + try: + a = _Types._GeneralUnit(a) + except Exception: + raise ValueError( + "Unable to parse 'alpha' string." + ) from None + + elif not isinstance(a, _Types.Energy): + raise TypeError( + "'alpha' must be of type 'BioSimSpace.Types.Energy', 'str', or 'float'." + ) + + # Validate the dimensions. + if a.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'alpha' has invalid dimensions! 
" + f"Expected dimensions of energy density (e.g. kcal/mol), found '{a.unit()}'" + ) + alpha_fin.append(a) + self._alpha = alpha_fin + elif alpha is None: + self._alpha = [0.00 * _Units.Energy.kcal_per_mol] * self._num_lambda + else: + raise TypeError("'alpha' must be of type 'list' or None") + + def getUh(self): + """ + Return the uh values. + + Returns + ------- + + uh : [:class:`Energy ] + The uh values in kcal/mol. + """ + return self._uh + + def setUh(self, uh): + """ + Set the uh values. + + Parameters + ---------- + + uh : [:class:`Energy ] + The uh values in kcal/mol. + """ + if isinstance(uh, list): + if len(uh) != self._num_lambda: + raise ValueError("'uh' must have the same length as 'num_lambda'") + uh_fin = [] + for u in uh: + # Convert int to float. + if type(u) is int: + u = float(u) + u *= _Units.Energy.kcal_per_mol + + if isinstance(u, float): + # Use default units. + u *= _Units.Energy.kcal_per_mol + + else: + if isinstance(u, str): + try: + u = _Types._GeneralUnit(u) + except Exception: + raise ValueError( + "Unable to parse 'alpha' string." + ) from None + + elif not isinstance(u, _Types.Energy): + raise TypeError( + "'alpha' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if u.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'alpha' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{u.unit()}'" + ) + uh_fin.append(u) + self._uh = uh_fin + elif uh is None: + self._uh = [0.00 * _Units.Energy.kcal_per_mol] * self._num_lambda + else: + raise TypeError("'uh' must be of type 'list'") + + def getW0(self): + """ + Return the W0 values. + + Returns + ------- + + W0 : [:class:`Energy ] + The W0 values in kcal/mol. + """ + return self._W0 + + def setW0(self, W0): + """ + Set the W0 values. + + Parameters + ---------- + + W0 : [:class:`Energy ] or [int], [float], [str] + The W0 values in kcal/mol. 
+ """ + if isinstance(W0, list): + if len(W0) != self._num_lambda: + raise ValueError("'W0' must have the same length as 'num_lambda'") + W0_fin = [] + for w in W0: + # Convert int to float. + if type(w) is int: + w = float(w) + w *= _Units.Energy.kcal_per_mol + + if isinstance(w, float): + # Use default units. + w *= _Units.Energy.kcal_per_mol + + else: + if isinstance(w, str): + try: + w = _Types._GeneralUnit(w) + except Exception: + raise ValueError( + "Unable to parse 'alpha' string." + ) from None + + elif not isinstance(w, _Types.Energy): + raise TypeError( + "'alpha' must be of type 'BioSimSpace.Types._GeneralUnit', 'str', or 'float'." + ) + + # Validate the dimensions. + if w.dimensions() != (1, 2, -2, 0, 0, -1, 0): + raise ValueError( + "'alpha' has invalid dimensions! " + f"Expected dimensions of energy density (e.g. kcal/mol), found '{w.unit()}'" + ) + W0_fin.append(w) + self._W0 = W0_fin + elif W0 is None: + self._W0 = [0.00 * _Units.Energy.kcal_per_mol] * self._num_lambda + else: + raise TypeError("'W0' must be of type 'list'") + + def _set_lambda_values(self): + # Internal function to set the 'master lambda' + # This lambda value serves as the master for all other window-dependent parameters + self._lambda_values = _np.linspace(0, 1, self._num_lambda).tolist() + + def _get_lambda_values(self): + # Internal function to get the 'master lambda' + # This lambda value serves as the master for all other window-dependent parameters + try: + return self._lambda_values + except: + return None + + def setAnalysisMethod(self, analysis_method): + """Set the method that will be used for analysis of the simulation results. + This will change the output files that are generated. + + Parameters + ---------- + analysis_method : str + The method to use for analysis. Options are "UWHAM", "MBAR" or "both" + This affects the output files and the analysis that is performed. + USE of "UWHAM" is strongly recommended, "MBAR" analysis is still experimental. 
+ """ + allowed_methods = ["UWHAM", "MBAR", "both"] + if analysis_method in allowed_methods: + self._analysis_method = analysis_method + else: + raise ValueError(f"analysis_method must be one of {allowed_methods}") + + def getAnalysisMethod(self): + return self._analysis_method + + def set_current_index(self, index): + """ + A function to set the index of the current lambda window. + Used internally to set the values for all lambda-dependent parameters. + Take care when using this function as it can lead to unexpected behaviour if not used correctly. + + Parameters + ---------- + index : int + The index of the current lambda window. + """ + if index < 0: + raise ValueError("index must be positive") + if index >= len(self._lambda1): + raise ValueError( + "index must be less than the number of lambda1 values (len(lambda1))" + ) + if not isinstance(index, int): + raise TypeError("index must be an integer") + self._current_index = index + + def get_window_index(self): + """ + A function to get the index of the current lambda window. + + Returns + ------- + index : int + The index of the current lambda window. + """ + try: + return self._current_index + except: + return None diff --git a/python/BioSimSpace/Sandpit/Exscientia/_SireWrappers/_system.py b/python/BioSimSpace/Sandpit/Exscientia/_SireWrappers/_system.py index 283c84ef8..282d5edfd 100644 --- a/python/BioSimSpace/Sandpit/Exscientia/_SireWrappers/_system.py +++ b/python/BioSimSpace/Sandpit/Exscientia/_SireWrappers/_system.py @@ -1238,8 +1238,9 @@ def rotateBoxVectors( from sire.system import System - # Create a cursor. - cursor = System(self._sire_object).cursor() + # Create a cursor for the non-perturbable molecules. + system = System(self._sire_object) + cursor = system["not property is_perturbable"].cursor() # Rotate all vector properties. @@ -1267,8 +1268,15 @@ def rotateBoxVectors( except: pass + # Update the molecules in the system. 
+ system.update(cursor.commit()) + self._sire_object = system._system + # Now deal with any perturbable molecules. if self.nPerturbableMolecules() > 0: + # Create a cursor for the perturbable molecules. + cursor = system["property is_perturbable"].cursor() + # Coordinates. try: prop_name = property_map.get("coordinates", "coordinates") + "0" @@ -1307,8 +1315,9 @@ def rotateBoxVectors( except: pass - # Commit the changes. - self._sire_object = cursor.commit()._system + # Update the perturbable molecules in the system. + system.update(cursor.commit()) + self._sire_object = system._system def reduceBoxVectors(self, bias=0, property_map={}): """ diff --git a/python/BioSimSpace/_SireWrappers/_system.py b/python/BioSimSpace/_SireWrappers/_system.py index c00beac2c..9cbf17010 100644 --- a/python/BioSimSpace/_SireWrappers/_system.py +++ b/python/BioSimSpace/_SireWrappers/_system.py @@ -1186,8 +1186,9 @@ def rotateBoxVectors( from sire.system import System - # Create a cursor. - cursor = System(self._sire_object).cursor() + # Create a cursor for the non-perturbable molecules. + system = System(self._sire_object) + cursor = system["not property is_perturbable"].cursor() # Rotate all vector properties. @@ -1215,8 +1216,15 @@ def rotateBoxVectors( except: pass + # Update the molecules in the system. + system.update(cursor.commit()) + self._sire_object = system._system + # Now deal with any perturbable molecules. if self.nPerturbableMolecules() > 0: + # Create a cursor for the perturbable molecules. + cursor = system["property is_perturbable"].cursor() + # Coordinates. try: prop_name = property_map.get("coordinates", "coordinates") + "0" @@ -1255,8 +1263,9 @@ def rotateBoxVectors( except: pass - # Commit the changes. - self._sire_object = cursor.commit()._system + # Update the perturbable molecules in the system. 
+ system.update(cursor.commit()) + self._sire_object = system._system def reduceBoxVectors(self, bias=0, property_map={}): """ diff --git a/tests/FreeEnergy/test_atm.py b/tests/FreeEnergy/test_atm.py new file mode 100644 index 000000000..f2a9b0183 --- /dev/null +++ b/tests/FreeEnergy/test_atm.py @@ -0,0 +1,681 @@ +import math +import pytest +import requests +import tarfile +import tempfile +import json +import pandas as pd +import os + +import BioSimSpace as BSS + + +def test_makeSystem(TEMOA_host, TEMOA_lig1, TEMOA_lig2): + + atm_generator = BSS.FreeEnergy.ATMSetup( + receptor=TEMOA_host, ligand_bound=TEMOA_lig1, ligand_free=TEMOA_lig2 + ) + # check that an error is thrown in the rigid core atoms are not given to prepare + with pytest.raises(TypeError): + atm_system, atm_data = atm_generator.prepare() + + rigid_core = [1, 2, 3] + + atm_system, atm_data = atm_generator.prepare( + ligand_bound_rigid_core=rigid_core, ligand_free_rigid_core=rigid_core + ) + + # Check that the system contains an atm data property + data_from_system = json.loads(atm_system._sire_object.property("atom_data").value()) + to_ignore = ["displacement"] + # check that atm_data and data_from_system are the same, ignoring anything in to_ignore + assert all( + [ + data_from_system[key] == atm_data[key] + for key in atm_data + if key not in to_ignore + ] + ) + + # check that data[ligand_bound_rigid_core] and data[ligand_free_rigid_core] are the same as the input + assert data_from_system["ligand_bound_rigid_core"] == rigid_core + assert data_from_system["ligand_free_rigid_core"] == rigid_core + + # get the coordinates of the ligands + lig1_coords = atm_system[atm_data["ligand_bound_index"]]._sire_object.coordinates() + lig2_coords = atm_system[atm_data["ligand_free_index"]]._sire_object.coordinates() + # make sure the displacement is correct for the default value of 20A + assert pytest.approx((lig2_coords - lig1_coords).length().value(), rel=1) == 20.0 + + vector = BSS.Types.Vector(10.0, 10.0, 
10.0) + + system_withvec, data_withvec = atm_generator.prepare( + ligand_bound_rigid_core=rigid_core, + ligand_free_rigid_core=rigid_core, + displacement=vector, + ) + + data_from_system = json.loads( + system_withvec._sire_object.property("atom_data").value() + ) + assert pytest.approx(data_from_system["displacement"], rel=1e-3) == [ + vector.x(), + vector.y(), + vector.z(), + ] + lig1_coords = system_withvec[ + data_withvec["ligand_bound_index"] + ]._sire_object.coordinates() + lig2_coords = system_withvec[ + data_withvec["ligand_free_index"] + ]._sire_object.coordinates() + + d = lig2_coords - lig1_coords + assert pytest.approx(d.x().value(), 1) == vector.x() + assert pytest.approx(d.y().value(), 1) == vector.y() + assert pytest.approx(d.z().value(), 1) == vector.z() + + # make a new atm_generator and check the parsing of a full system + atm_generator = BSS.FreeEnergy.ATMSetup(system=atm_system) + + +def test_run(TEMOA_hostguest): + system, _ = TEMOA_hostguest + production_atm = BSS.Protocol.ATMProduction( + system=system, + com_distance_restraint=True, + runtime="2 fs", + report_interval=1, + restart_interval=1, + num_lambda=2, + analysis_method="UWHAM", + ) + production_atm2 = BSS.Protocol.ATMProduction( + system=system, + com_distance_restraint=True, + runtime="4 fs", + report_interval=1, + restart_interval=1, + num_lambda=2, + analysis_method="UWHAM", + ) + with tempfile.TemporaryDirectory() as tmpdirname: + production = BSS.FreeEnergy.ATM(system, production_atm, work_dir=tmpdirname) + production.run() + production.wait() + # read openmm.csv and make sure it has a single row + df = pd.read_csv(os.path.join(tmpdirname, "lambda_0.0000/openmm.csv")) + assert len(df) == 1 + + production2 = BSS.FreeEnergy.ATM(system, production_atm2, work_dir=tmpdirname) + production2.run() + production2.wait() + df = pd.read_csv(os.path.join(tmpdirname, "lambda_0.0000/openmm.csv")) + assert len(df) == 2 + + +def test_single_point_energies(TEMOA_host, TEMOA_lig1, TEMOA_lig2): + # 
Tests the single point energies of the + # Mirroring inputs for G. lab code + lig1_cm_atoms_absolute = [ + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + ] + lig2_cm_atoms_absolute = [ + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + ] + lig1_cm_rel = [x - 196 for x in lig1_cm_atoms_absolute] + lig2_cm_rel = [x - 217 for x in lig2_cm_atoms_absolute] + prot_cm_atoms = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + ] + atm_generator = BSS.FreeEnergy.ATMSetup( + receptor=TEMOA_host, ligand_bound=TEMOA_lig1, ligand_free=TEMOA_lig2 + ) + system, data = atm_generator.prepare( + displacement=[22, 22, 22], + 
ligand_bound_rigid_core=[8, 6, 4], + ligand_free_rigid_core=[3, 5, 1], + ligand_bound_com_atoms=lig1_cm_rel, + ligand_free_com_atoms=lig2_cm_rel, + protein_com_atoms=prot_cm_atoms, + ) + + pos_rst_atoms = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + ] + + production_atm = BSS.Protocol.ATMProduction( + system=system, + com_distance_restraint=True, + com_k=25.0, + com_restraint_width=5.0, + restraint=pos_rst_atoms, + positional_restraint_width=0.5, + force_constant=25.0, + align_k_psi=10.0, + align_k_theta=10.0, + align_k_distance=2.5, + runtime="100 ps", + num_lambda=22, + soft_core_umax=100.0, + soft_core_a=0.0625, + soft_core_u0=50.0, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + production = BSS.Process.OpenMM( + system, + production_atm, + platform="CPU", + setup_only=True, + work_dir=tmpdirname, + **{"_is_testing": True}, + ) + production.start() + production.wait() + + assert not production.isError() + # now get the file containing single points + df = pd.read_csv(os.path.join(tmpdirname, "energies_singlepoint.csv")) + ens = df.to_dict() + + # Here we are specifically verifying the energies of the ATMForce + ens_GL = { + 0.0: 2847.6, + 0.0476: 2849.1, + 0.0952: 2850.7, + 0.1429: 2852.2, + 0.1905: 2853.7, + 0.2381: 2855.3, + 0.2857: 2856.8, + 0.3333: 2858.4, + 0.381: 2859.9, + 0.4286: 2861.5, + 0.4762: 2863.0, + } + # Need to add an offset due to treatment of 1-4 forces in GL code + offset = 803.3 + # now check that the energies are the same + for lam, en in ens_GL.items(): + assert pytest.approx(ens[str(lam)][0], rel=1) == en + offset + + # Now check the rest of the forces + df_nonlam = pd.read_csv(os.path.join(tmpdirname, "non_lambda_forces.csv")) + ens_nonlam = df_nonlam.to_dict() + + ens_GL_nolam = { + "com": 0.0, + "distance": 0.02983, + "angle": 0.0072010, 
+ "dihedral": 3.55355e-13, + "position_restraint": 0.0, + } + for key, en in ens_GL_nolam.items(): + assert pytest.approx(ens_nonlam[key][0], rel=1e-3) == en + + +def test_UWHAM(): + import numpy as np + + # To try and ensure parity with the Gallicchio lab code + # we will test each individual element of the UWHAM calculation + potential = -69702.79 + e_pert = 70.48908 + beta = 1.678238963 + lambda1 = 0.0 + lambda2 = 0.0 + alpha = 0.0 + u0 = 0.0 + w0 = 0.0 + + n_pot = 116977.9 + + from BioSimSpace.FreeEnergy._ddg import _npot_fcn + + npot = _npot_fcn( + e0=potential, + epert=e_pert, + bet=beta, + lam1=lambda1, + lam2=lambda2, + alpha=alpha, + u0=u0, + w0=w0, + ) + + assert pytest.approx(npot, rel=1e-3) == n_pot + + # Now testing agreement with known values from + # UWHAM-R analysis + ln_q_array = np.array( + [ + [ + 117251.85638785, + 117147.70578372, + 117259.71235395, + 117184.35793014, + 116934.45115, + 117405.64541825, + 116930.39936544, + 117131.36660758, + 117072.35871073, + 117041.11910054, + 117166.97160247, + ], + [ + 117246.36847836, + 117141.83269751, + 117254.16656958, + 117181.69670683, + 116933.05433974, + 117404.90353356, + 116930.17474063, + 117130.18191686, + 117071.63726339, + 117041.08918386, + 117167.13191866, + ], + [ + 117240.88056888, + 117135.95961129, + 117248.62078521, + 117179.03548353, + 116931.65752948, + 117404.16164887, + 116929.95011581, + 117128.99722614, + 117070.91581606, + 117041.05926718, + 117167.29223485, + ], + [ + 117235.3926594, + 117130.08652508, + 117243.07500085, + 117176.37426023, + 116930.26071922, + 117403.41976418, + 116929.725491, + 117127.81253542, + 117070.19436872, + 117041.0293505, + 117167.45255104, + ], + [ + 117229.90474991, + 117124.21343886, + 117237.52921648, + 117173.71303693, + 116928.86390896, + 117402.67787949, + 116929.50086619, + 117126.6278447, + 117069.47292138, + 117040.99943382, + 117167.61286723, + ], + [ + 117224.41684043, + 117118.34035265, + 117231.98343212, + 117171.05181363, + 
116927.4670987, + 117401.9359948, + 116929.27624138, + 117125.44315397, + 117068.75147404, + 117040.96951714, + 117167.77318343, + ], + [ + 117218.92893094, + 117112.46726644, + 117226.43764775, + 117168.39059033, + 116926.07028844, + 117401.1941101, + 116929.05161657, + 117124.25846325, + 117068.0300267, + 117040.93960046, + 117167.93349962, + ], + [ + 117213.44102146, + 117106.59418022, + 117220.89186339, + 117165.72936702, + 116924.67347817, + 117400.45222541, + 116928.82699176, + 117123.07377253, + 117067.30857936, + 117040.90968378, + 117168.09381581, + ], + [ + 117207.95311198, + 117100.72109401, + 117215.34607902, + 117163.06814372, + 116923.27666791, + 117399.71034072, + 116928.60236695, + 117121.88908181, + 117066.58713202, + 117040.8797671, + 117168.254132, + ], + [ + 117202.46520249, + 117094.84800779, + 117209.80029465, + 117160.40692042, + 116921.87985765, + 117398.96845603, + 116928.37774214, + 117120.70439109, + 117065.86568468, + 117040.84985042, + 117168.41444819, + ], + [ + 117196.97729301, + 117088.97492158, + 117204.25451029, + 117157.74569712, + 116920.48304739, + 117398.22657134, + 116928.15311733, + 117119.51970037, + 117065.14423734, + 117040.81993374, + 117168.57476438, + ], + ] + ) + + n_samples = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + + known_answer = 12.42262 + known_error = 1.422156 + + from BioSimSpace.FreeEnergy._ddg import _estimate_f_i + + f_i, d_i, weights = _estimate_f_i(ln_q_array, n_samples) + ddg = f_i[-1] - f_i[0] + ddg = ddg / beta + ddg_error = np.sqrt(d_i[-1] + d_i[0]) / beta + + assert pytest.approx(ddg, rel=1e-3) == known_answer + assert pytest.approx(ddg_error, rel=1e-3) == known_error diff --git a/tests/Process/test_atm.py b/tests/Process/test_atm.py new file mode 100644 index 000000000..151182c54 --- /dev/null +++ b/tests/Process/test_atm.py @@ -0,0 +1,126 @@ +import pytest + +import BioSimSpace as BSS + + +def test_atm_minimisation(TEMOA_hostguest): + # First get a system with data + system, data = TEMOA_hostguest + # 
Generate a minimisation protocol + prot_min = BSS.Protocol.ATMMinimisation(data=data, steps=1) + + run_process(system, prot_min) + del system, data + + +@pytest.mark.parametrize("use_atm_force", [True, False]) +def test_atm_equilibration(TEMOA_hostguest, use_atm_force): + # First get a system with data + system, data = TEMOA_hostguest + # Generate an equilibration protocol + prot_equil = BSS.Protocol.ATMEquilibration( + data=data, + runtime="4 fs", + use_atm_force=use_atm_force, + report_interval=1, + restart_interval=1, + ) + + run_process(system, prot_equil) + del system, data + + +def test_atm_anneal(TEMOA_hostguest): + # First get a system with data + system, data = TEMOA_hostguest + # Generate an annealing protocol + prot_anneal = BSS.Protocol.ATMAnnealing( + data=data, + runtime="4 fs", + report_interval=1, + restart_interval=1, + anneal_numcycles=1, + ) + + run_process(system, prot_anneal) + del system, data + + +def test_custom_atm_anneal(TEMOA_hostguest): + # First get a system with data + system, data = TEMOA_hostguest + # now test passing a valid dictionary + annealing_dict = { + "lambda1_start": 0.0, + "lambda1_end": 0.5, + "lambda2_start": 0.0, + "lambda2_end": 0.5, + "alpha_start": 0.0, + "alpha_end": 0.5, + "uh_start": 0.0, + "uh_end": 0.5, + "W0_start": 0.0, + "W0_end": 0.5, + } + protocol = BSS.Protocol.ATMAnnealing( + data=data, + anneal_values=annealing_dict, + anneal_numcycles=1, + runtime="2 fs", + report_interval=1, + restart_interval=1, + ) + run_process(system, protocol) + + +def test_atm_production(TEMOA_hostguest): + # First get a system with data + system, data = TEMOA_hostguest + # Generate a production protocol + prot_prod = BSS.Protocol.ATMProduction( + data=data, + runtime="2 fs", + report_interval=1, + restart_interval=1, + ) + + run_process(system, prot_prod) + + # now test "MBAR" analysis method + prot_prod = BSS.Protocol.ATMProduction( + data=data, + runtime="2 fs", + analysis_method="MBAR", + report_interval=1, + 
restart_interval=1, + ) + run_process(system, prot_prod) + + # finally, test the "both" analysis method + prot_prod = BSS.Protocol.ATMProduction( + data=data, + runtime="2 fs", + analysis_method="both", + report_interval=1, + restart_interval=1, + ) + run_process(system, prot_prod) + + +def run_process(system, protocol): + """Helper function to run various simulation protocols.""" + + # Initialise the OpenMM process. + process = BSS.Process.OpenMM(system, protocol, name="test") + + # Start the OpenMM simulation. + process.start() + + # Wait for the process to end. + process.wait() + + # Make sure the process didn't error. + assert not process.isError() + + # Make sure that we get a molecular system back. + assert process.getSystem() is not None diff --git a/tests/Protocol/test_atm.py b/tests/Protocol/test_atm.py new file mode 100644 index 000000000..c2f066bc4 --- /dev/null +++ b/tests/Protocol/test_atm.py @@ -0,0 +1,285 @@ +import pytest +import BioSimSpace as BSS + + +def test_atm_minimisation(TEMOA_hostguest): + # We will use this as a test for all of the parent class inputs + + # First need to test that both forms of the `data` input work + system, data = TEMOA_hostguest + BSS.Protocol.ATMMinimisation(data=data) + BSS.Protocol.ATMMinimisation(system=system) + # Now test the optional inputs, first using biosimspace Units + protocol_units = BSS.Protocol.ATMMinimisation( + data=data, + core_alignment=False, + com_distance_restraint=False, + restraint="all", + force_constant=1.0 * (BSS.Units.Energy.kcal_per_mol / BSS.Units.Area.angstrom2), + positional_restraint_width=0.1 * BSS.Units.Length.angstrom, + align_k_distance=1.0 * BSS.Units.Energy.kcal_per_mol / BSS.Units.Area.angstrom2, + align_k_theta=1.0 * BSS.Units.Energy.kcal_per_mol, + align_k_psi=1.0 * BSS.Units.Energy.kcal_per_mol, + soft_core_umax=10.0 * BSS.Units.Energy.kcal_per_mol, + soft_core_u0=1.0 * BSS.Units.Energy.kcal_per_mol, + soft_core_a=0.01, + com_k=1.0 * BSS.Units.Energy.kcal_per_mol / 
BSS.Units.Area.angstrom2, + com_restraint_width=1.0 * BSS.Units.Length.angstrom, + ) + + # Now test parsing options as floats + protocol_floats = BSS.Protocol.ATMMinimisation( + data=data, + force_constant=1.0, + positional_restraint_width=0.1, + align_k_distance=1.0, + align_k_theta=1.0, + align_k_psi=1.0, + soft_core_umax=10.0, + soft_core_u0=1.0, + soft_core_a=0.01, + com_k=1.0, + com_restraint_width=1.0, + ) + + # Finally try parsing strings + protocol_strings = BSS.Protocol.ATMMinimisation( + data=data, + force_constant="1.0 kcal mol^-1 angstrom^-2", + positional_restraint_width="0.1 angstrom", + align_k_distance="1.0 kcal mol^-1 angstrom^-2", + align_k_theta="1.0 kcal mol^-1", + align_k_psi="1.0 kcal mol^-1", + soft_core_umax="10.0 kcal mol^-1", + soft_core_u0="1.0 kcal mol^-1", + soft_core_a=0.01, + com_k="1.0 kcal mol^-1 angstrom^-2", + com_restraint_width="1.0 angstrom", + ) + # using getters, check that all protocols have the same values + # (skip force constant as it is not atm exclusive) + assert ( + protocol_units.getPosRestWidth() + == protocol_floats.getPosRestWidth() + == protocol_strings.getPosRestWidth() + ) + assert ( + protocol_units.getAlignKDistance() + == protocol_floats.getAlignKDistance() + == protocol_strings.getAlignKDistance() + ) + assert ( + protocol_units.getAlignKTheta() + == protocol_floats.getAlignKTheta() + == protocol_strings.getAlignKTheta() + ) + assert ( + protocol_units.getAlignKPsi() + == protocol_floats.getAlignKPsi() + == protocol_strings.getAlignKPsi() + ) + assert ( + protocol_units.getSoftCoreUmax() + == protocol_floats.getSoftCoreUmax() + == protocol_strings.getSoftCoreUmax() + ) + assert ( + protocol_units.getSoftCoreU0() + == protocol_floats.getSoftCoreU0() + == protocol_strings.getSoftCoreU0() + ) + assert ( + protocol_units.getSoftCoreA() + == protocol_floats.getSoftCoreA() + == protocol_strings.getSoftCoreA() + ) + assert ( + protocol_units.getCOMk() + == protocol_floats.getCOMk() + == 
protocol_strings.getCOMk() + ) + assert ( + protocol_units.getCOMWidth() + == protocol_floats.getCOMWidth() + == protocol_strings.getCOMWidth() + ) + + +def test_atm_equilibration(TEMOA_hostguest): + # Testing equilibration-specific inputs + system, data = TEMOA_hostguest + + protocol_units = BSS.Protocol.ATMEquilibration( + data=data, + timestep=1 * BSS.Units.Time.femtosecond, + runtime=0.1 * BSS.Units.Time.nanosecond, + temperature_start=200 * BSS.Units.Temperature.kelvin, + temperature_end=300 * BSS.Units.Temperature.kelvin, + pressure=0.99 * BSS.Units.Pressure.atm, + thermostat_time_constant=1.5 * BSS.Units.Time.picosecond, + report_interval=1000, + restart_interval=1001, + use_atm_force=True, + direction=-1, + lambda1=0.1, + lambda2=0.2, + alpha=0.1 * BSS.Units.Energy.kcal_per_mol, + uh=0.1 * BSS.Units.Energy.kcal_per_mol, + W0=0.1 * BSS.Units.Energy.kcal_per_mol, + ) + + # test setting alpha,uh and w0 as floats + protocol_floats = BSS.Protocol.ATMEquilibration( + data=data, + alpha=0.1, + uh=0.1, + W0=0.1, + ) + + # test setting alpha,uh and w0 as strings + protocol_strings = BSS.Protocol.ATMEquilibration( + data=data, + timestep="1 fs", + runtime="0.1 ns", + temperature_start="200 K", + temperature_end="300 K", + pressure="0.99 atm", + thermostat_time_constant="1.5 ps", + alpha="0.1 kcal mol^-1", + uh="0.1 kcal mol^-1", + W0="0.1 kcal mol^-1", + ) + + # Check that all protocols have the same values + assert protocol_units.getTimeStep() == protocol_strings.getTimeStep() + assert protocol_units.getRunTime() == protocol_strings.getRunTime() + assert ( + protocol_units.getStartTemperature() == protocol_strings.getStartTemperature() + ) + assert protocol_units.getEndTemperature() == protocol_strings.getEndTemperature() + assert protocol_units.getPressure() == protocol_strings.getPressure() + assert ( + protocol_units.getThermostatTimeConstant() + == protocol_strings.getThermostatTimeConstant() + ) + assert ( + protocol_units.getAlpha() + == 
protocol_floats.getAlpha() + == protocol_strings.getAlpha() + ) + assert protocol_units.getUh() == protocol_floats.getUh() == protocol_strings.getUh() + assert protocol_units.getW0() == protocol_floats.getW0() == protocol_strings.getW0() + + +def test_atm_annealing(TEMOA_hostguest): + # Testing annealing-specific inputs + system, data = TEMOA_hostguest + + # first test passing an invalid key in the annealing dictionary + annealing_dict = { + "lambda1_start": 0.0, + "lambda1_end": 0.5, + "lambda2_start": 0.0, + "lambda2_end": 0.5, + "invalid_key": 0.0, + } + with pytest.raises(ValueError): + BSS.Protocol.ATMAnnealing(data=data, anneal_values=annealing_dict) + + # now test passing a valid dictionary + annealing_dict = { + "lambda1_start": 0.0, + "lambda1_end": 0.5, + "lambda2_start": 0.0, + "lambda2_end": 0.5, + "alpha_start": 0.0, + "alpha_end": 0.5, + "uh_start": 0.0, + "uh_end": 0.5, + "W0_start": 0.0, + "W0_end": 0.5, + } + protocol = BSS.Protocol.ATMAnnealing( + data=data, + anneal_values=annealing_dict, + anneal_numcycles=10, + alpha="0.1 kcal mol^-1", + uh="0.1 kcal mol^-1", + W0="0.1 kcal mol^-1", + ) + + +def test_atm_production(TEMOA_hostguest): + # Testing production-specific inputs + system, data = TEMOA_hostguest + + # first create a production protocol with num_lambda=6 + protocol = BSS.Protocol.ATMProduction( + data=data, + num_lambda=6, + ) + # get values for direction, lambda1, lambda2, alpha, uh and W0 + assert len(protocol.getDirection()) == 6 + assert len(protocol.getLambda1()) == 6 + assert len(protocol.getLambda2()) == 6 + assert len(protocol.getAlpha()) == 6 + assert len(protocol.getUh()) == 6 + assert len(protocol.getW0()) == 6 + + # Define custom values for direction that are not valid and check that an error is raised + d = [1, 2, 3, 4, 5, 6] + with pytest.raises(ValueError): + protocol.setDirection(d) + + # Define custom values for lambda1 that are not valid and check that an error is raised + l1 = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + with 
pytest.raises(ValueError): + protocol.setLambda1(l1) + + # Define custom values for lambda2 that are not valid and check that an error is raised + l2 = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + with pytest.raises(ValueError): + protocol.setLambda2(l2) + + # check that a list of strings, ints, floats and BSS units can be parsed for alpha,uh and w0 + list_of_units = [ + "0.1 kcal mol^-1", + 0.1, + 0.1 * BSS.Units.Energy.kcal_per_mol, + 0.1 * BSS.Units.Energy.kcal_per_mol, + 0.1, + 1, + ] + + protocol = BSS.Protocol.ATMProduction( + data=data, + num_lambda=6, + alpha=list_of_units, + uh=list_of_units, + W0=list_of_units, + ) + + end_product = [0.1 * BSS.Units.Energy.kcal_per_mol] * 5 + end_product.append(1 * BSS.Units.Energy.kcal_per_mol) + assert protocol.getAlpha() == end_product + assert protocol.getUh() == end_product + assert protocol.getW0() == end_product + + # now check that all of the allowed analysis options can be set + protocol = BSS.Protocol.ATMProduction( + data=data, + num_lambda=6, + analysis_method="UWHAM", + ) + + protocol = BSS.Protocol.ATMProduction( + data=data, + num_lambda=6, + analysis_method="MBAR", + ) + + protocol = BSS.Protocol.ATMProduction( + data=data, + num_lambda=6, + analysis_method="both", + ) diff --git a/tests/conftest.py b/tests/conftest.py index 46be2d10a..04106d872 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,3 +86,39 @@ def solvated_perturbable_system(): f"{url}/solvated_perturbable_system1.prm7", f"{url}/solvated_perturbable_system1.rst7", ) + + +@pytest.fixture(scope="session") +def TEMOA_host(): + host = BSS.IO.readMolecules( + BSS.IO.expand(BSS.tutorialUrl(), ["temoa_host.rst7", "temoa_host.prm7"]) + )[0] + return host + + +@pytest.fixture(scope="session") +def TEMOA_lig1(): + lig1 = BSS.IO.readMolecules( + BSS.IO.expand(BSS.tutorialUrl(), ["temoa_ligG1.rst7", "temoa_ligG1.prm7"]) + )[0] + return lig1 + + +@pytest.fixture(scope="session") +def TEMOA_lig2(): + lig2 = BSS.IO.readMolecules( + 
BSS.IO.expand(BSS.tutorialUrl(), ["temoa_ligG4.rst7", "temoa_ligG4.prm7"]) + )[0] + return lig2 + + +@pytest.fixture(scope="session") +def TEMOA_hostguest(TEMOA_host, TEMOA_lig1, TEMOA_lig2): + atm_generator = BSS.FreeEnergy.ATMSetup( + receptor=TEMOA_host, ligand_bound=TEMOA_lig1, ligand_free=TEMOA_lig2 + ) + rigid_core = [1, 2, 3] + atm_system, atm_data = atm_generator.prepare( + ligand_bound_rigid_core=rigid_core, ligand_free_rigid_core=rigid_core + ) + return atm_system, atm_data