Skip to content

Commit

Permalink
fix gradients through cuda features
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed May 11, 2024
1 parent 5f36c01 commit 7be59dd
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 47 deletions.
16 changes: 8 additions & 8 deletions scripts/villin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ using PyCall

## Config

comment = "momenta"

pdb = "data/villin nowater.pdb"
steps = 10_000
step = 0.002
temp = 310
friction = 0.1
friction = 1
integrator = :langevinmiddle
momenta = true
features = 0.5 # 0 => backbone only
#forcefields = OpenMM.FORCE_AMBER_IMPLICIT # TODO: this shouldnb be an option the way we build addwater now
forcefields = OpenMM.FORCE_AMBER
Expand All @@ -36,7 +39,6 @@ opt = ISOKANN.NesterovRegularized(1e-3, 1e-4)
sigma = 2
maxjump = 1

comment = "lag10ps gamma0.1"
path = "out/villin/$(now())"[1:end-4] * comment

readdata = nothing
Expand All @@ -51,13 +53,11 @@ println("lagtime: $lagtime ns")
println("simtime per generation: $simtime_per_gen ns")

@time "creating system" sim = OpenMMSimulation(;
pdb, steps, forcefields, features, friction, step,
temp, nthreads=1, mmthreads="gpu")
pdb, steps, forcefields, features, friction, step, momenta, temp, nthreads=1, mmthreads="gpu")

if addwater
@time "adding water" sim = OpenMMSimulation(;
pdb, steps, forcefields, friction, step,
temp, nthreads=1, mmthreads="gpu",
pdb, steps, forcefields, friction, step, momenta, temp, nthreads=1, mmthreads="gpu",
features=sim.features, addwater=true, padding, ionicstrength)
end

Expand Down Expand Up @@ -128,8 +128,8 @@ for i in 1:generations
ISOKANN.Plots.savefig(plot_training(iso), "$path/villin_fold_$(simtime)ps.png")
println("\n status: $path/villin_fold_$(simtime)ps.png \n")
ISOKANN.save("$path/iso.jld2", iso)
catch e
@show e
catch
@show catch_backtrace()
end
end
end
9 changes: 4 additions & 5 deletions src/extrapolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extrapolate them by `stepsize` for `steps` steps beyond their extrema,
resulting in 2n new points.
If `minimize` is true, the new points are energy minimized.
"""
function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true)
function extrapolate(iso, n::Integer, stepsize=0.1, steps=1, minimize=true)
data = iso.data
model = iso.model
coords = flattenlast(data.coords[2])
Expand All @@ -41,7 +41,7 @@ function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true)
for (p, dir, N) in [(p, -1, n), (reverse(p), 1, 2 * n)]
for i in p
try
x = extrapolate(data, model, coords[:, i], dir * stepsize, steps)
x = extrapolate(iso, coords[:, i], dir * stepsize, steps)
minimize && (x = energyminimization_chilevel(iso, x))
push!(xs, x)
catch e
Expand All @@ -60,12 +60,11 @@ function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true)
return xs
end

function extrapolate(d, model, x::AbstractVector, step, steps)
function extrapolate(iso, x::AbstractVector, step, steps)
x = copy(x)
for _ in 1:steps
grad = dchidx(d, model, x)
grad = dchidx(iso, x)
x .+= grad ./ norm(grad)^2 .* step
#@show model(features(d,x))
end
return x
end
Expand Down
2 changes: 1 addition & 1 deletion src/iso2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function train_batch!(model, xs::AbstractMatrix, ys::AbstractMatrix, opt, miniba
end

chis(iso::Iso2) = iso.model(getxs(iso.data))
chicoords(iso::Iso2, xs) = iso.model(iso.data.featurizer(xs))
chicoords(iso::Iso2, xs) = iso.model(features(iso.data, iscuda(iso.model) ? gpu(xs) : xs))
isotarget(iso::Iso2) = isotarget(iso.model, getobs(iso.data)..., iso.transform)

Optimisers.adjust!(iso::Iso2; kwargs...) = Optimisers.adjust!(iso.opt; kwargs...)
Expand Down
4 changes: 3 additions & 1 deletion src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ inputdim(model::Flux.Dense) = size(model.weight, 2)
outputdim(model::Flux.Chain) = outputdim(model.layers[end])
outputdim(model::Flux.Dense) = size(model.weight, 1)

iscuda(m::Flux.Chain) = Flux.params(m)[1] isa CuArray



""" convenience wrapper returning the provided model with the default AdamW optimiser """
Expand Down Expand Up @@ -47,7 +49,7 @@ function pairnet((xs, ys)::Tuple; kwargs...)
end

""" Fully connected neural network with `layers` layers from `n` to `nout` dimensions.
`features` allows to pass a featurizer as preprocessor,
`features` allows to pass a featurizer as preprocessor,
`activation` determines the activation function for each but the last layer
`lastactivation` can be used to modify the last layers activation function """
function pairnet(n::Int=22; layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1)
Expand Down
21 changes: 8 additions & 13 deletions src/reactionpath.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
using LinearAlgebra: normalize

function dchidx(iso, x=getcoords(iso.data)[:, 1])
dchidx(iso.data, iso.model, x)
end

function dchidx(data, model, x)
function dchidx(iso, x)
Zygote.gradient(x) do x
model(features(data, x)) |> myonly
end[1]
chicoords(iso, x) |> myonly
end |> only
end


## works on gpu as well
function myonly(x)
if length(x) == 1
return sum(x)
else
error("only scalar net is supported here")
end
myonly(x) = only(x)
function myonly(x::CuArray)
@assert length(x) == 1
return sum(x)
end

"""
Expand Down
5 changes: 4 additions & 1 deletion src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ end
gpu(d::SimulationData) = SimulationData(d.sim, gpu(d.features), d.coords, d.featurizer)
cpu(d::SimulationData) = SimulationData(d.sim, cpu(d.features), d.coords, d.featurizer)

features(d::SimulationData, x) = d.featurizer(x)
function features(d::SimulationData, x)
d.features[1] isa CuArray && (x = cu(x))
return d.featurizer(x)
end

featuredim(d::SimulationData) = size(d.features[1], 1)
nk(d::SimulationData) = size(d.features[2], 2)
Expand Down
36 changes: 18 additions & 18 deletions src/simulators/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,6 @@ function OpenMMSimulation(;
return OpenMMSimulation(pysim::PyObject, pdb, ligand, forcefields, temp, friction, step, steps, features, nthreads, mmthreads, momenta)
end

function Base.show(io::IO, mime::MIME"text/plain", sim::OpenMMSimulation)#
println(
io, """
OpenMMSimulation(;
pdb="$(sim.pdb)",
ligand="$(sim.ligand)",
forcefields=$(sim.forcefields),
temp=$(sim.temp),
friction=$(sim.friction),
step=$(sim.step),
steps=$(sim.steps),
features=$(sim.features))"""
)
end

function featurizer(sim::OpenMMSimulation)
if sim.features isa (Vector{Int})
Expand Down Expand Up @@ -162,18 +148,18 @@ function propagate(s::OpenMMSimulation, x0::AbstractMatrix{T}, ny; stepsize=s.st
dim, nx = size(x0)
xs = repeat(x0, outer=[1, ny])
xs = permutedims(reinterpret(Tuple{T,T,T}, xs))
ys = @pycall py"threadedrun"(xs, s.pysim, stepsize, steps, nthreads, mmthreads, momenta)::PyArray
ys = @pycall py"threadedrun"(xs, s.pysim, stepsize, steps, nthreads, mmthreads, momenta)::Vector{Float32}
ys = reshape(ys, dim, nx, ny)
ys = permutedims(ys, (1, 3, 2))
checkoverflow(ys) # control the simulated data for NaNs and too large entries and throws an error
return convert(AbstractArray{Float32}, ys)
return ys#convert(Array{Float32,3}, ys)
end

propagate(s::OpenMMSimulation, x0::CuArray, ny; nthreads=Threads.nthreads()) = cu(propagate(s, collect(x0), ny; nthreads))
#propagate(s::OpenMMSimulation, x0::CuArray, ny; nthreads=Threads.nthreads()) = cu(propagate(s, collect(x0), ny; nthreads))

struct OpenMMOverflow{T} <: Exception where {T}
result::T
select::Vector{Bool}
select::Vector{Bool} # flags which results are valid
end

function checkoverflow(ys, overflow=100)
Expand Down Expand Up @@ -332,6 +318,20 @@ Base.convert(::Type{OpenMMSimulation}, a::OpenMMSimulationSerialized) =
nthreads=a.nthreads,
mmthreads=a.mmthreads)

function Base.show(io::IO, mime::MIME"text/plain", sim::OpenMMSimulation)#
println(
io, """
OpenMMSimulation(;
pdb="$(sim.pdb)",
ligand="$(sim.ligand)",
forcefields=$(sim.forcefields),
temp=$(sim.temp),
friction=$(sim.friction),
step=$(sim.step),
steps=$(sim.steps),
features=$(sim.features))"""
)
end

end #module

Expand Down

0 comments on commit 7be59dd

Please sign in to comment.