Skip to content

Commit

Permalink
include the dashboard into the package, convenience trajectorydata ge…
Browse files Browse the repository at this point in the history
…nerators
  • Loading branch information
axsk committed Jun 27, 2024
1 parent dd17b35 commit 6a9cabe
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 35 deletions.
3 changes: 3 additions & 0 deletions src/ISOKANN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ include("reactionpath2.jl")
include("IsoMu/IsoMu.jl")
include("vgv/vgv.jl")

include("makie.jl")
include("bonito.jl")

end
19 changes: 8 additions & 11 deletions src/bonito.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
using Revise
using Bonito
using WGLMakie
using ISOKANN

includet("makie.jl")

#iso = Iso2(OpenMMSimulation(steps=30, pdb="data/vgv.pdb", forcefields=ISOKANN.OpenMM.FORCE_AMBER_IMPLICIT), loggers=[], opt=ISOKANN.NesterovRegularized(1e-3, 0), nx=10, nk=1, minibatch=64)

USEGPU = true
ISO = nothing
Expand Down Expand Up @@ -111,10 +105,13 @@ function isocreator()
); width="300px",), isoo
end

app = App(title="ISOKANN Dashboard") do session
return content(session)
end

function serve()
app = App(title="ISOKANN Dashboard") do session
return content(session)
end

server = Bonito.get_server()
route!(server, "/" => app)
server = Bonito.get_server()
route!(server, "/" => app)
return app
end
10 changes: 0 additions & 10 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,6 @@ end

@deprecate shuffledata(data) shuffleobs(data)

# TODO: this does not belong here!
""" trajectory(sim, nx)
generate a trajectory of length `nx` from the simulation `sim`"""
function trajectory(sim, nx)
siml = deepcopy(sim)
logevery = round(Int, sim.T / sim.dt)
siml.T = sim.T * nx
xs = solve(siml; logevery=logevery)
return xs
end


### Data I/O
Expand Down
5 changes: 2 additions & 3 deletions src/makie.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using WGLMakie
using ISOKANN
using WGLMakie.Observables: throttle
using ThreadPools

Expand Down Expand Up @@ -71,7 +70,7 @@ end



function dashboard(iso::Iso2, session)
function dashboard(iso::Iso2, session=nothing)
coords = Observable(iso.data.coords[1] |> cpu)
chis = Observable(ISOKANN.chis(iso) |> vec |> cpu)
icur = Observable(1)
Expand Down Expand Up @@ -144,7 +143,7 @@ function dashboard(iso::Iso2, session)
ThreadPools.@tspawnat 1 begin
try
last = time()
while isready(session) && run.active[]
while (isnothing(session) || isready(session)) && run.active[]

run!(iso)
adaptivesampling.active[] && ISOKANN.resample_kde!(iso, 1; padding=0.01, bandwidth=0.1)
Expand Down
3 changes: 2 additions & 1 deletion src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ end
`features` allows to pass a featurizer as preprocessor,
`activation` determines the activation function for each but the last layer
`lastactivation` can be used to modify the last layers activation function """
function pairnet(n::Int=22; layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1)
function pairnet(n::Int=22; layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1, layernorm=false)
float32(x) = Float32.(x)
nn = Flux.Chain(
#float32,
features,
layernorm ? Flux.LayerNorm(n) : identity,
[Flux.Dense(
round(Int, n^(l / layers)),
round(Int, n^((l - 1) / layers)),
Expand Down
14 changes: 14 additions & 0 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function randx0(sim::IsoSimulation, nx)
return xs
end

trajectory(sim::IsoSimulation, steps) = error("not implemented")
laggedtrajectory(sim::IsoSimulation, nx) = error("not implemented")

###

"""
Expand Down Expand Up @@ -191,3 +194,14 @@ end
function datasize((xs, ys)::Tuple)
return size(xs), size(ys)
end

function trajectorydata(sim::IsoSimulation, steps; reverse=false, kwargs...)
xs = laggedtrajectory(sim, steps)
SimulationData(sim, data_from_trajectory(xs; reverse), kwargs...)
end

function trajectoryburstdata(sim, steps, nk; kwargs)
xs = laggedtrajectory(sim, steps)
ys = propagate(sim, xs, nk)
SimulationData(sim, ys, kwargs...)
end
10 changes: 10 additions & 0 deletions src/simulators/langevin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ function propagate(l::AbstractLangevin, x0::AbstractMatrix, ny)
return ys
end

""" trajectory(sim, nx)
generate a trajectory of length `nx` from the simulation `sim`"""
function trajectory(sim, nx)
siml = deepcopy(sim)
logevery = round(Int, sim.T / sim.dt)
siml.T = sim.T * nx
xs = solve(siml; logevery=logevery)
return xs
end

function solve_end(l::AbstractLangevin; u0)
StochasticDiffEq.solve(SDEProblem(l, u0))[:, end]
end
Expand Down
11 changes: 2 additions & 9 deletions src/simulators/mopenmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,8 @@ def trajectory(sim, x0, stepsize, steps, saveevery, mmthreads, withmomenta):
trajectory[0] = x0

c = newcontext(sim.context, mmthreads)

if withmomenta:
n = len(x0) // 2
c.setPositions(x0[:n])
c.setVelocities(x0[n:])
else:
c.setPositions(x0)
c.setVelocitiesToTemperature(sim.integrator.getTemperature())


set_numpy_state(c, x0)
c.getIntegrator().setStepSize(stepsize)

for n in range(1,n_states):
Expand Down
8 changes: 7 additions & 1 deletion src/simulators/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import ISOKANN: ISOKANN, IsoSimulation,
propagate, dim, randx0,
featurizer, defaultmodel,
savecoords, getcoords, force, pdb,
force, potential, lagtime
force, potential, lagtime, trajectory

export OpenMMSimulation, FORCE_AMBER, FORCE_AMBER_IMPLICIT

Expand Down Expand Up @@ -193,6 +193,12 @@ function trajectory(s::OpenMMSimulation, x0::AbstractVector{T}=getcoords(s), ste
return xs
end

function ISOKANN.laggedtrajectory(s::OpenMMSimulation, nlags)
steps = s.steps * nlags
saveevery = s.steps
trajectory(s, getcoords(s), steps, saveevery)
end

getcoords(sim::OpenMMSimulation) = getcoords(sim.pysim, sim.momenta)#::Vector
setcoords(sim::OpenMMSimulation, coords) = setcoords(sim.pysim, coords, sim.momenta)

Expand Down

0 comments on commit 6a9cabe

Please sign in to comment.