Skip to content

Commit

Permalink
coherent state handling in openmm
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed May 13, 2024
1 parent cc2fe76 commit 5e865d9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 39 deletions.
69 changes: 42 additions & 27 deletions src/simulators/mopenmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,16 @@
def threadedrun(xs, sim, stepsize, steps, nthreads, nthreadssim=1, withmomenta=False):
def singlerun(i):
c = newcontext(sim.context, nthreadssim)
if withmomenta:
n = len(xs[i]) // 2
c.setPositions(xs[i][:n])
c.setVelocities(xs[i][n:])
else:
c.setPositions(xs[i])
c.setVelocitiesToTemperature(sim.integrator.getTemperature())

set_numpy_state(c, xs[i], withmomenta)
c.getIntegrator().setStepSize(stepsize)

try:
c.getIntegrator().step(steps)

except OpenMMException as e:
print("Error integrating trajectory", e)
x = c.getState(getPositions=True).getPositions(asNumpy=True).value_in_unit(nanometer)
x.fill(np.nan)
return x

if withmomenta:
state = c.getState(getPositions=True, getVelocities=True)
x = np.concatenate([
state.getPositions(asNumpy=True).value_in_unit(nanometer),
state.getVelocities(asNumpy=True).value_in_unit(nanometer/picosecond)])
else:
x = c.getState(getPositions=True).getPositions(asNumpy=True).value_in_unit(nanometer)

return get_numpy_state(c, withmomenta).fill(np.nan)

x = get_numpy_state(c, withmomenta)
return x


Expand All @@ -42,19 +26,26 @@ def singlerun(i):
out = [singlerun(i) for i in range(len(xs))]
return np.array(out).flatten()

def trajectory(sim, x0, stepsize, steps, saveevery, mmthreads):
def trajectory(sim, x0, stepsize, steps, saveevery, mmthreads, withmomenta):
n_states = steps // saveevery + 1
trajectory = np.zeros((n_states,) + np.array(x0).shape)
trajectory[0] = x0

c = newcontext(sim.context, mmthreads)
c.setPositions(x0)
c.setVelocitiesToTemperature(c.getIntegrator().getTemperature())

if withmomenta:
n = len(x0) // 2
c.setPositions(x0[:n])
c.setVelocities(x0[n:])
else:
c.setPositions(x0)
c.setVelocitiesToTemperature(sim.integrator.getTemperature())

c.getIntegrator().setStepSize(stepsize)

for n in range(1,n_states):
c.getIntegrator().step(saveevery)
trajectory[n] = get_numpy_pos(c)
trajectory[n] = get_numpy_state(c, withmomenta)

return trajectory

Expand Down Expand Up @@ -106,12 +97,36 @@ def defaultsystem(pdb, ligand, forcefields, temp, friction, step, minimize, plat

simulation.context.setPositions(modeller.positions)
simulation.context.setVelocitiesToTemperature(simulation.integrator.getTemperature())

simulation.reporters.append(
StateDataReporter(
"openmmsimulation.log", 1, step=True,
potentialEnergy=True, totalEnergy=True,
temperature=True, speed=True,)
)

if minimize:
simulation.minimizeEnergy(maxIterations=100)
return simulation

def get_numpy_pos(context):
return context.getState(getPositions=True).getPositions(asNumpy=True).value_in_unit(nanometer)
def get_numpy_state(context, withmomenta):
if withmomenta:
state = context.getState(getPositions=True, getVelocities=True)
x = np.concatenate([
state.getPositions(asNumpy=True).value_in_unit(nanometer),
state.getVelocities(asNumpy=True).value_in_unit(nanometer/picosecond)])
else:
x = context.getState(getPositions=True).getPositions(asNumpy=True).value_in_unit(nanometer)
return x

def set_numpy_state(context, x, withmomenta):
if withmomenta:
n = len(x) // 2
context.setPositions(x[:n])
context.setVelocities(x[n:])
else:
context.setPositions(x)
context.setVelocitiesToTemperature(sim.integrator.getTemperature())

def newcontext(context, mmthreads):
if mmthreads == 'gpu':
Expand Down
16 changes: 4 additions & 12 deletions src/simulators/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ end
Return the coordinates of a single trajectory started at `x0` for the given number of `steps` where each `saveevery` step is stored.
"""
function trajectory(s::OpenMMSimulation, x0::AbstractVector{T}, steps=s.steps, saveevery=1; stepsize=s.step, mmthreads=s.mmthreads) where {T}
function trajectory(s::OpenMMSimulation, x0::AbstractVector{T}=getcoords(s), steps=s.steps, saveevery=1; stepsize=s.step, mmthreads=s.mmthreads, momenta=s.momenta) where {T}
x0 = reinterpret(Tuple{T,T,T}, x0)
xs = py"trajectory"(s.pysim, x0, stepsize, steps, saveevery, mmthreads)
xs = py"trajectory"(s.pysim, x0, stepsize, steps, saveevery, mmthreads, momenta)
xs = permutedims(xs, (3, 2, 1))
xs = reshape(xs, :, size(xs, 3))
return xs
Expand All @@ -185,16 +185,8 @@ end
getcoords(sim::OpenMMSimulation) = getcoords(sim.pysim, sim.momenta)#::Vector
setcoords(sim::OpenMMSimulation, coords) = setcoords(sim.pysim, coords, sim.momenta)

function getcoords(sim::PyObject, momenta)
st = sim.context.getState(getPositions=true, getVelocities=momenta)
x = st.getPositions(asNumpy=true).flatten()
if !momenta
return x
else
v = st.getVelocities(asNumpy=true).flatten()
return vcat(x, v)
end
end
getcoords(sim::PyObject, momenta) = py"get_numpy_state($sim.context, $momenta).flatten()"


function setcoords(sim::PyObject, coords::AbstractVector{T}, momenta) where {T}
t = reinterpret(Tuple{T,T,T}, Array(coords))
Expand Down

0 comments on commit 5e865d9

Please sign in to comment.