[WIP] Using the robust solver for pyMBAR - avoiding convergence Failu… #735

Merged
Changes from 12 commits
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -59,6 +59,7 @@ jobs:
       - name: Setup micromamba
         uses: mamba-org/setup-micromamba@v1
         with:
+          micromamba-version: '2.0.0-0'
           environment-file: devtools/conda-envs/test_env.yaml
           environment-name: openmmtools-test
           create-args: >-
7 changes: 7 additions & 0 deletions openmmtools/multistate/multistateanalyzer.py
@@ -24,6 +24,7 @@
 import copy
 import inspect
 import logging
+from packaging.version import Version
 import re
 from typing import Optional, NamedTuple, Union

@@ -37,6 +38,7 @@
 import simtk.unit as units
 from scipy.special import logsumexp

+from openmmtools.multistate import pymbar
 from openmmtools import multistate, utils, forces
 from openmmtools.multistate.pymbar import (
     statistical_inefficiency_multiple,
@@ -567,6 +569,11 @@ def __init__(self, reporter, name=None, reference_states=(0, -1),
         self.reference_states = reference_states
         self._user_extra_analysis_kwargs = analysis_kwargs  # Store the user-specified (higher priority) keywords

+        # If we are using pymbar 4, change the default behavior to use the robust solver protocol if the user
+        # didn't set a kwarg to control the solver protocol
+        if Version(pymbar.__version__) >= Version("4") and "solver_protocol" not in self._user_extra_analysis_kwargs:
+            self._user_extra_analysis_kwargs["solver_protocol"] = "robust"
+
         # Initialize cached values that are read or derived from the Reporter.
         self._cache = {}  # This cache should be always set with _update_cache().
         self.clear()
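The defaulting rule in this hunk can be tried out in isolation. The sketch below is not the PR's code: the helper name `with_default_solver_protocol` is hypothetical, it is stdlib-only, and it compares just the major version number, whereas the diff uses `packaging.version.Version`. It also assumes the user kwargs may be `None` and copies them instead of writing into the caller's dictionary.

```python
def with_default_solver_protocol(user_kwargs, pymbar_version):
    """Return a copy of user_kwargs, defaulting solver_protocol to "robust" on pymbar >= 4.

    Hypothetical helper for illustration only; it is not part of openmmtools.
    """
    kwargs = dict(user_kwargs or {})  # copy, so the caller's dict is left untouched
    # Assumption: the version string starts with an integer major version, e.g. "4.0.3".
    major = int(pymbar_version.split(".")[0])
    if major >= 4 and "solver_protocol" not in kwargs:
        kwargs["solver_protocol"] = "robust"
    return kwargs


# "user-choice" is a placeholder value, not a real pymbar protocol name.
print(with_default_solver_protocol({}, "4.0.3"))
print(with_default_solver_protocol({"solver_protocol": "user-choice"}, "4.0.3"))
print(with_default_solver_protocol({}, "3.1.1"))
```

An explicit `solver_protocol` from the user always wins, and nothing changes for pymbar 3. Whether the analyzer should copy the user's kwargs before adding the default is a design choice the visible hunk does not settle.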
3 changes: 2 additions & 1 deletion openmmtools/multistate/pymbar.py
@@ -10,7 +10,7 @@
         subsample_correlated_data,
         statistical_inefficiency
     )
-    from pymbar import MBAR
+    from pymbar import MBAR, __version__
     from pymbar.utils import ParameterError
 except ImportError:
     # pymbar < 4
@@ -22,6 +22,7 @@
     )
     from pymbar import MBAR
     from pymbar.utils import ParameterError
+    from pymbar.version import short_version as __version__


 def _pymbar_bar(
41 changes: 40 additions & 1 deletion openmmtools/tests/test_sampling.py
@@ -22,12 +22,14 @@
 import shutil
 import sys
 import tempfile
+import time
 from io import StringIO

 import numpy as np
 import yaml

 import pytest
+import requests

 try:
     import openmm
@@ -306,6 +308,7 @@ def run(self, include_unsampled_states=False):
         # Clean up.
         del simulation

+    @pytest.mark.flaky(reruns=3)
     def test_with_unsampled_states(self):
         """Test multistate sampler on a harmonic oscillator with unsampled endstates"""
         self.run(include_unsampled_states=True)
@@ -1861,7 +1864,7 @@ def test_analysis_opens_without_checkpoint(self):
             del reporter
             self.REPORTER(storage_path, checkpoint_storage=cp_file_mod, open_mode="r")

-    @pytest.mark.flaky(reruns=3)
+    @pytest.mark.skipif(sys.platform == "darwin", reason="seg faults on osx sometimes")
     def test_storage_reporter_and_string(self):
         """Test that creating a MultiState by storage string and reporter is the same"""
         thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(
@@ -2612,6 +2615,42 @@ def test_resume_velocities_from_legacy_storage(self):
state.velocities.value_in_unit_system(unit.md_unit_system) != 0
), "At least some velocity in sampler state from new checkpoint is expected to different from zero."

+@pytest.fixture
+def download_nc_file(tmpdir):
Contributor

Would using something like pooch be better for this kind of thing?

Contributor

I thought about that, but it felt like a lot to add for a single test. If we have more files to download, we can add it then.

+    FILE_URL = "https://github.com/user-attachments/files/17156868/ala-thr.zip"
Contributor

I believe we cannot rely on this long term, since as far as I know the structure of these URLs can change without notice. Is there another way we can have a long-lived url? Not a big issue but probably something to keep in mind.

+    MAX_RETRIES = 3
+    RETRY_DELAY = 2  # Delay between retries (in seconds)
+    file_name = os.path.join(tmpdir, "ala-thr.nc")
+    retries = 0
+    while retries < MAX_RETRIES:
+        try:
+            # Send GET request to download the file
+            response = requests.get(FILE_URL, timeout=20)  # Timeout to avoid hanging
+            response.raise_for_status()  # Raise HTTPError for bad responses (4xx/5xx)
+            with open(file_name, "wb") as f:
+                f.write(response.content)
+            # File downloaded successfully, break out of retry loop
+            break
+
+        except (requests.exceptions.RequestException, requests.exceptions.HTTPError) as e:
+            retries += 1
+            if retries >= MAX_RETRIES:
+                pytest.fail(f"Failed to download file after {MAX_RETRIES} retries: {e}")
+            else:
+                print(f"Retrying download... ({retries}/{MAX_RETRIES})")
+                time.sleep(RETRY_DELAY)  # Wait before retrying
+    yield file_name
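The retry loop in this fixture can be separated from the network call so that it is testable offline. This is a minimal sketch under stated assumptions, not the PR's implementation: `fetch_with_retries` and `flaky_fetch` are hypothetical names, the fetch step is any zero-argument callable, and the sketch catches every `Exception` where the fixture catches `requests.exceptions.RequestException`.

```python
import time


def fetch_with_retries(fetch, max_retries=3, retry_delay=2.0, sleep=time.sleep):
    """Call fetch() until it succeeds or max_retries attempts have failed.

    Hypothetical helper: fetch is any zero-argument callable that returns
    the payload or raises an exception.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return fetch()
        except Exception as error:  # a real fixture would catch a narrower exception type
            last_error = error
            if attempt < max_retries:
                sleep(retry_delay)  # wait before the next attempt
    raise RuntimeError(f"failed after {max_retries} attempts") from last_error


# Usage with a fake fetch that fails twice and then succeeds.
calls = []


def flaky_fetch():
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("simulated network failure")
    return b"payload"


result = fetch_with_retries(flaky_fetch, sleep=lambda seconds: None)
print(result, len(calls))
```

Passing `sleep` as a parameter lets a test skip the real delay. In the fixture, the callable would wrap the `requests.get` and `raise_for_status` calls shown in the diff.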


+def test_pymbar_issue_419(download_nc_file):
+    from openmmtools.multistate import MultiStateReporter, MultiStateSamplerAnalyzer
+
+    n_iterations = 1000
+    reporter_file = download_nc_file
+    reporter = MultiStateReporter(reporter_file)
+    analyzer = MultiStateSamplerAnalyzer(reporter, max_n_iterations=n_iterations)
+    f_ij, df_ij = analyzer.get_free_energy()
Contributor

I would encourage doing a numerical regression check here rather than just a pure smoke test.

Contributor

I thought about that, but we already do that in other tests. Happy to add it. What threshold do we want to use to check that it is close?

Contributor

This shouldn't be sampling anything, so ideally we should always be converging to a very similar value. If it changes, something changed enough to affect all our free energies. I would suggest something like 1e-4 or 1e-5 precision.

Contributor

I'll be honest, I forgot we were not sampling anything, so I didn't even consider that we would get the same result (or very close to it) each time we run.
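The regression check discussed in this thread might look like the sketch below. It is a sketch under stated assumptions: the reference value is a placeholder, not a number computed from `ala-thr.nc`, and `matches_reference` is a hypothetical helper. A real test would store the value obtained once from the downloaded file and compare a chosen element of `f_ij` against it.

```python
import math

# Placeholder only: a real test would hard-code the value computed once from the stored file.
REFERENCE_FREE_ENERGY = -1.0
TOLERANCE = 1e-5  # absolute tolerance; the thread suggests 1e-4 or 1e-5


def matches_reference(value, reference=REFERENCE_FREE_ENERGY, tolerance=TOLERANCE):
    """Return True when value agrees with reference to within an absolute tolerance."""
    return math.isclose(value, reference, rel_tol=0.0, abs_tol=tolerance)


print(matches_reference(-1.000001))  # differs by about 1e-6, inside the tolerance
print(matches_reference(-1.001))     # differs by about 1e-3, outside the tolerance
```

Since the test file already imports NumPy, `numpy.testing.assert_allclose` with an `atol` argument would be a natural alternative for comparing the whole matrix.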

Contributor

We probably need a docstring here telling what the test is doing and how the nc file was generated. For future reference.



# ==============================================================================
# MAIN AND TESTS