prototype ISO2 interface

axsk committed Feb 6, 2024
1 parent 043b299 commit cfbf8ea
Showing 7 changed files with 113 additions and 24 deletions.
78 changes: 76 additions & 2 deletions scripts/adaptivesampling.jl
@@ -6,7 +6,7 @@ if false # workaround vscode "missing reference bug"
using .ISOKANN: randx0, propagate, TransformShiftscale, TransformISA, adddata, isosteps, defaultmodel, OpenMMSimulation
end

using ISOKANN: randx0, propagate, TransformShiftscale, TransformISA, adddata, isosteps, defaultmodel, OpenMMSimulation
using ISOKANN: ISOKANN, randx0, propagate, TransformShiftscale, TransformISA, adddata, isosteps, defaultmodel, OpenMMSimulation, scatter_ramachandran, scatter_chifix, plot_chi

function adapt_setup(;
steps=100,
@@ -60,5 +60,79 @@ function gpu(iso::NamedTuple)
ys = Flux.gpu(ys)

iso = (; xs, ys, model, opt, sim, transform, losses, targets)
end

@kwdef mutable struct ISO2
sim
transform
model
opt

nepochs::Int
nresample::Int
npower::Int
nupdate::Int

xs
ys
losses
targets
end


# The new IsoRun()
# missing features:
# live visualization / loggers
# minibatch

function IsoRun2(;
sim=OpenMMSimulation(),
nchi=1,
transform=(nchi == 1 ? TransformShiftscale() : TransformISA()),
nlayers=4,
activation=Flux.relu,
model=defaultmodel(sim; nout=nchi, activation, layers=nlayers),
lr=1e-4,
decay=1e-5,
opt=Flux.setup(Flux.AdamW(lr, (0.9, 0.999), decay), model),
nx0=10,
nmc=10,
nepochs=1,
nresample=0,
npower=1,
nupdate=1,
xs=randx0(sim, nx0),
ys=propagate(sim, xs, nmc),
losses=Float64[],
targets=Matrix{Float64}[],
)
ISO2(; sim, transform, model, opt,
nepochs, nresample, npower, nupdate,
xs, ys, losses, targets)
end

function ISOKANN.run!(iso::ISO2)
(; sim, transform, model, opt,
nepochs, nresample, npower, nupdate,
losses, targets) = iso

for _ in 1:nepochs
@time "resampling" iso.xs, iso.ys = adddata((iso.xs, iso.ys), model, sim, nresample,)
@time "training" isosteps(model, opt, (iso.xs, iso.ys), npower, nupdate; transform, losses, targets)
end

return iso
end

using Plots

function plot_learning(iso::ISO2)
(; losses, xs, model) = iso
p1 = plot(losses, yaxis=:log, title="loss", label="trainloss", xlabel="iter")
p2 = plot_chi(xs, vec(model(xs)))
p3 = scatter_chifix((iso.xs, iso.ys), model)
ps = [p1, p2, p3]
plot(ps..., layout=(length(ps), 1), size=(400, 300 * length(ps)))
end


end
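
For orientation, a minimal usage sketch of the interface prototyped above, assuming the definitions in this script are in scope and an OpenMM environment is installed; the parameter values are purely illustrative:

using ISOKANN

# small counts keep this cheap enough for a smoke test;
# IsoRun2 falls back to OpenMMSimulation() and a default network
iso = IsoRun2(nx0=10, nmc=10, nepochs=3, nresample=5)

ISOKANN.run!(iso)   # alternate adaptive resampling and chi training
plot_learning(iso)  # loss curve, chi scatter and fixed-point plot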
3 changes: 2 additions & 1 deletion src/ISOKANN.jl
@@ -53,7 +53,6 @@ include("isomolly.jl") # ISOKANN for Molly systems
include("plots.jl") # visualizations
include("loggers.jl") # performance metric loggers
include("benchmarks.jl") # benchmark runs, deprecated by scripts/*
include("cuda.jl") # fixes for cuda
include("reactionpath.jl")

include("isosimple.jl")
@@ -65,6 +64,8 @@ include("openmm.jl")
import .OpenMM.OpenMMSimulation
export OpenMMSimulation

include("cuda.jl") # fixes for cuda

#include("dataloader.jl")

#include("precompile.jl") # precompile for faster ttx
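Note: moving the cuda.jl include below the OpenMM include presumably lets its CUDA-specific methods refer to the OpenMMSimulation type imported just above.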
2 changes: 1 addition & 1 deletion src/forced/langevin.jl
@@ -70,4 +70,4 @@ triplewell(x) = triplewell(x...)

Triplewell() = Diffusion(; potential=triplewell, dim=2, sigma=[1.0, 1.0])

defaultmodel(sim::AbstractLangevin; nout) = smallnet(dim(sim), nout)
defaultmodel(sim::AbstractLangevin; nout, activation, layers) = smallnet(dim(sim), nout)
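
The widened signature keeps the generic defaultmodel(sim; nout, activation, layers) call in IsoRun2 valid for Langevin-type simulations, which for now accept and ignore the two extra keywords. A hypothetical call:

using Flux
using ISOKANN: defaultmodel, Triplewell

sim = Triplewell()
model = defaultmodel(sim; nout=3, activation=Flux.relu, layers=4)  # activation/layers unused for AbstractLangevin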
15 changes: 8 additions & 7 deletions src/iso2.jl
@@ -52,6 +52,8 @@ function isosteps(model, opt, (xs, ys), nkoop=1, nupdate=1;
for i in 1:nkoop
target = isotarget(model, xs, ys, transform)
push!(targets, target)
# note that nupdate is classically called epochs
# TODO: should we use minibatches here?
for j in 1:nupdate
loss = ISOKANN.learnstep!(model, xs, target, opt) # Neural Network update
push!(losses, loss)
@@ -61,7 +63,6 @@ end
end



isosteps(iso::NamedTuple; kwargs...) = isosteps(; iso..., kwargs...)
function isosteps(; nkoop=1, nupdate=1, exp...)
exp = NamedTuple(exp)
@@ -82,7 +83,7 @@ If `direct==true` solve `chi * pinv(K(chi))`, otherwise `inv(K(chi) * pinv(chi))`
`normalize` specifies whether to renormalize the resulting target vectors.
`permute` specifies whether to permute the target for stability.
"""
@with_kw struct TransformPseudoInv
@kwdef struct TransformPseudoInv
normalize::Bool = true
direct::Bool = true
eigenvecs::Bool = true
@@ -119,7 +120,7 @@ end
Compute the target via the inner simplex algorithm (without feasibilization routine).
`permute` specifies whether to apply the stabilizing permutation """
@with_kw struct TransformISA
@kwdef struct TransformISA
permute::Bool = true
end

@@ -194,13 +195,13 @@ function isodata(diffusion, nx, ny)
end

function test_dw(; kwargs...)
iso2(nd=2, sim=Doublewell(); kwargs...)
vismodel(model)
i = iso2(nd=2, sim=Doublewell(); kwargs...)
vismodel(i.model)
end

function test_tw(; kwargs...)
iso2(nd=3, sim=Triplewell(); kwargs...)
vismodel(model)
i = iso2(nd=3, sim=Triplewell(); kwargs...)
vismodel(i.model)
end


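A minimal sketch of calling isosteps directly on a toy diffusion. Two assumptions beyond this diff: that isodata returns an (xs, ys) pair, as its call shape suggests, and that Doublewell is reachable under the ISOKANN namespace:

using Flux
using ISOKANN: isosteps, isodata, defaultmodel, TransformShiftscale, Doublewell

sim = Doublewell()
xs, ys = isodata(sim, 100, 10)   # 100 start points, 10 Koopman samples each (assumed return shape)
model = defaultmodel(sim; nout=1, activation=Flux.relu, layers=4)
opt = Flux.setup(Flux.AdamW(1e-4), model)
losses, targets = Float64[], Matrix{Float64}[]

# 10 power iterations, each followed by 5 network updates ("epochs" in the comment above)
isosteps(model, opt, (xs, ys), 10, 5; transform=TransformShiftscale(), losses, targets)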
8 changes: 4 additions & 4 deletions src/molutils.jl
@@ -21,14 +21,14 @@ end

dihedral(x::AbstractMatrix) = @views dihedral(x[:, 1], x[:, 2], x[:, 3], x[:, 4])

function psi(x::AbstractVector) # dihedral of the oxygens
function psi(x::AbstractVector, inds=[7, 9, 15, 17]) # dihedral of the oxygens
x = reshape(x, 3, :)
@views dihedral(x[:, [7, 9, 15, 17]])
@views dihedral(x[:, inds])
end

function phi(x::AbstractVector)
function phi(x::AbstractVector, inds=[5, 7, 9, 15])
x = reshape(x, 3, :)
@views dihedral(x[:, [5, 7, 9, 15]])
@views dihedral(x[:, inds])
end

phi(x::AbstractMatrix) = mapslices(phi, x, dims=1) |> vec
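The new inds arguments generalize the dihedral helpers beyond the hard-coded alanine-dipeptide atom indices. A sketch; the custom indices are illustrative, not taken from a real topology:

using ISOKANN: phi, psi

x = rand(3 * 22)         # flattened coordinates of a 22-atom system
phi(x)                   # backbone phi over the default atoms [5, 7, 9, 15]
phi(x, [4, 6, 8, 14])    # hypothetical indices for another molecule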
29 changes: 21 additions & 8 deletions src/plots.jl
@@ -41,28 +41,30 @@ function plot_learning(iso; subdata=nothing)
!isnothing(subdata) && (data = subdata)

p1 = plot(losses[1:end], yaxis=:log, title="loss", label="trainloss", xlabel="iter")

for tl in filter(l -> isa(l, TrainlossLogger), iso.loggers)
if length(tl.losses) > 1
plot!(tl.xs, tl.losses, label="validationloss")
end
end

xs, ys = data
# TODO: make this more appealing
p2 = if isa(iso.sim, IsoSimulation)
scatter_ramachandran(xs[1:66, :], model)
else
plot(model(xs) |> vec, label="", ylabel="χ", xlabel="frame")
end

p2 = plot_chi(xs, vec(model(xs)))
p3 = scatter_chifix(data, model)
#annotate!(0,0, repr(iso)[1:10])

ps = [p1, p2, p3]
plot(ps..., layout=(length(ps), 1), size=(400, 300 * length(ps)))
end

function plot_chi(xs, chi::AbstractVector)
if size(xs, 1) == 1
scatter(vec(xs), chi)
elseif size(xs, 1) == 2
scatter(xs[1, :], xs[2, :], marker_z=chi, label="")
elseif size(xs, 1) == 66
scatter_ramachandran(xs, chi)
end
end

""" fixed point plot, i.e. x vs model(x) """
function scatter_chifix(data, model)
@@ -95,3 +97,14 @@ function scatter_ramachandran(x::AbstractMatrix, model=nothing; kwargs...)
xlabel="\\phi", ylabel="\\psi", title="Ramachandran", ; kwargs...
)
end

scatter_ramachandran(x, model; kwargs...) = scatter_ramachandran(x, vec(model(x)))

function scatter_ramachandran(x::AbstractMatrix, z::Union{AbstractVector,Nothing}=nothing; kwargs...)
ph = phi(x)
ps = psi(x)
scatter(ph, ps, marker_z=z, xlims=[-pi, pi], ylims=[-pi, pi],
markersize=3, markerstrokewidth=0, markeralpha=1, markercolor=:tofino,
xlabel="\\phi", ylabel="\\psi", title="Ramachandran", ; kwargs...
)
end
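
The new plot_chi dispatches on the coordinate dimension: 1-d data yields a plain scatter, 2-d a chi-colored scatter, and 66 rows (22 atoms x 3 coordinates, i.e. alanine dipeptide) a Ramachandran plot; any other dimension currently produces no plot. A toy example:

using ISOKANN: plot_chi

xs = rand(2, 100)    # 100 samples of a 2-d toy system
chi = rand(100)      # stand-in for vec(model(xs))
plot_chi(xs, chi)    # 2-d scatter colored by chi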
2 changes: 1 addition & 1 deletion src/potentials.jl
@@ -32,7 +32,7 @@ mueller_brown_2d_man(x::AbstractArray) = mueller_brown_2d_man(eachslice(x, dims=
mueller_brown = mueller_brown_2d_man

# Extends AbstractLangevin from forced/langevin.jl
@with_kw struct MuellerBrown <: AbstractLangevin
@kwdef struct MuellerBrown <: AbstractLangevin
tmax::Float64 = 0.01
sigma::Float64 = 2.0
end
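The @with_kw -> @kwdef replacements here and in src/iso2.jl swap the Parameters.jl macro for its Base Julia equivalent (unqualified @kwdef is exported from Base since Julia 1.9), presumably to drop the dependency. Both generate a keyword constructor with field defaults:

Base.@kwdef struct MuellerBrownDemo   # hypothetical stand-in for the struct above
    tmax::Float64 = 0.01
    sigma::Float64 = 2.0
end

MuellerBrownDemo()            # MuellerBrownDemo(0.01, 2.0)
MuellerBrownDemo(sigma=1.0)   # override a single default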
