From 7be59ddb1fedadd60d7f7a1b56d280e89cbe7418 Mon Sep 17 00:00:00 2001 From: Sikorski Date: Sat, 11 May 2024 19:27:51 +0200 Subject: [PATCH] fix gradients through cuda features --- scripts/villin.jl | 16 ++++++++-------- src/extrapolate.jl | 9 ++++----- src/iso2.jl | 2 +- src/models.jl | 4 +++- src/reactionpath.jl | 21 ++++++++------------- src/simulation.jl | 5 ++++- src/simulators/openmm.jl | 36 ++++++++++++++++++------------------ 7 files changed, 46 insertions(+), 47 deletions(-) diff --git a/scripts/villin.jl b/scripts/villin.jl index 20054e4..e668203 100644 --- a/scripts/villin.jl +++ b/scripts/villin.jl @@ -5,12 +5,15 @@ using PyCall ## Config +comment = "momenta" + pdb = "data/villin nowater.pdb" steps = 10_000 step = 0.002 temp = 310 -friction = 0.1 +friction = 1 integrator = :langevinmiddle +momenta = true features = 0.5 # 0 => backbone only #forcefields = OpenMM.FORCE_AMBER_IMPLICIT # TODO: this shouldnb be an option the way we build addwater now forcefields = OpenMM.FORCE_AMBER @@ -36,7 +39,6 @@ opt = ISOKANN.NesterovRegularized(1e-3, 1e-4) sigma = 2 maxjump = 1 -comment = "lag10ps gamma0.1" path = "out/villin/$(now())"[1:end-4] * comment readdata = nothing @@ -51,13 +53,11 @@ println("lagtime: $lagtime ns") println("simtime per generation: $simtime_per_gen ns") @time "creating system" sim = OpenMMSimulation(; - pdb, steps, forcefields, features, friction, step, - temp, nthreads=1, mmthreads="gpu") + pdb, steps, forcefields, features, friction, step, momenta, temp, nthreads=1, mmthreads="gpu") if addwater @time "adding water" sim = OpenMMSimulation(; - pdb, steps, forcefields, friction, step, - temp, nthreads=1, mmthreads="gpu", + pdb, steps, forcefields, friction, step, momenta, temp, nthreads=1, mmthreads="gpu", features=sim.features, addwater=true, padding, ionicstrength) end @@ -128,8 +128,8 @@ for i in 1:generations ISOKANN.Plots.savefig(plot_training(iso), "$path/villin_fold_$(simtime)ps.png") println("\n status: $path/villin_fold_$(simtime)ps.png \n") ISOKANN.save("$path/iso.jld2", iso) - catch e - @show e + catch + @show catch_backtrace() end end end diff --git a/src/extrapolate.jl b/src/extrapolate.jl index 2495fac..d026ef4 100644 --- a/src/extrapolate.jl +++ b/src/extrapolate.jl @@ -28,7 +28,7 @@ extrapolate them by `stepsize` for `steps` steps beyond their extrema, resulting in 2n new points. If `minimize` is true, the new points are energy minimized. """ -function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true) +function extrapolate(iso, n::Integer, stepsize=0.1, steps=1, minimize=true) data = iso.data model = iso.model coords = flattenlast(data.coords[2]) @@ -41,7 +41,7 @@ function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true) for (p, dir, N) in [(p, -1, n), (reverse(p), 1, 2 * n)] for i in p try - x = extrapolate(data, model, coords[:, i], dir * stepsize, steps) + x = extrapolate(iso, coords[:, i], dir * stepsize, steps) minimize && (x = energyminimization_chilevel(iso, x)) push!(xs, x) catch e @@ -60,12 +60,11 @@ function extrapolate(iso, n, stepsize=0.1, steps=1, minimize=true) return xs end -function extrapolate(d, model, x::AbstractVector, step, steps) +function extrapolate(iso, x::AbstractVector, step, steps) x = copy(x) for _ in 1:steps - grad = dchidx(d, model, x) + grad = dchidx(iso, x) x .+= grad ./ norm(grad)^2 .* step - #@show model(features(d,x)) end return x end diff --git a/src/iso2.jl b/src/iso2.jl index b7fd3d1..5cf697b 100644 --- a/src/iso2.jl +++ b/src/iso2.jl @@ -105,7 +105,7 @@ function train_batch!(model, xs::AbstractMatrix, ys::AbstractMatrix, opt, miniba end chis(iso::Iso2) = iso.model(getxs(iso.data)) -chicoords(iso::Iso2, xs) = iso.model(iso.data.featurizer(xs)) +chicoords(iso::Iso2, xs) = iso.model(features(iso.data, iscuda(iso.model) ? gpu(xs) : xs)) isotarget(iso::Iso2) = isotarget(iso.model, getobs(iso.data)..., iso.transform) Optimisers.adjust!(iso::Iso2; kwargs...) = Optimisers.adjust!(iso.opt; kwargs...) diff --git a/src/models.jl b/src/models.jl index 6cb9ecd..ccd8c07 100644 --- a/src/models.jl +++ b/src/models.jl @@ -19,6 +19,8 @@ inputdim(model::Flux.Dense) = size(model.weight, 2) outputdim(model::Flux.Chain) = outputdim(model.layers[end]) outputdim(model::Flux.Dense) = size(model.weight, 1) +iscuda(m::Flux.Chain) = Flux.params(m)[1] isa CuArray + """ convenience wrapper returning the provided model with the default AdamW optimiser """ @@ -47,7 +49,7 @@ function pairnet((xs, ys)::Tuple; kwargs...) end """ Fully connected neural network with `layers` layers from `n` to `nout` dimensions. -`features` allows to pass a featurizer as preprocessor, +`features` allows to pass a featurizer as preprocessor, `activation` determines the activation function for each but the last layer `lastactivation` can be used to modify the last layers activation function """ function pairnet(n::Int=22; layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1) diff --git a/src/reactionpath.jl b/src/reactionpath.jl index 5ebb9eb..194cc42 100644 --- a/src/reactionpath.jl +++ b/src/reactionpath.jl @@ -1,22 +1,17 @@ using LinearAlgebra: normalize -function dchidx(iso, x=getcoords(iso.data)[:, 1]) - dchidx(iso.data, iso.model, x) -end - -function dchidx(data, model, x) +function dchidx(iso, x) Zygote.gradient(x) do x - model(features(data, x)) |> myonly - end[1] + chicoords(iso, x) |> myonly + end |> only end + ## works on gpu as well -function myonly(x) - if length(x) == 1 - return sum(x) - else - error("only scalar net is supported here") - end +myonly(x) = only(x) +function myonly(x::CuArray) + @assert length(x) == 1 + return sum(x) end """ diff --git a/src/simulation.jl b/src/simulation.jl index 36f4802..01d5f19 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -93,7 +93,10 @@ end gpu(d::SimulationData) = SimulationData(d.sim, gpu(d.features), d.coords, d.featurizer) cpu(d::SimulationData) = SimulationData(d.sim, cpu(d.features), d.coords, d.featurizer) -features(d::SimulationData, x) = d.featurizer(x) +function features(d::SimulationData, x) + d.features[1] isa CuArray && (x = cu(x)) + return d.featurizer(x) +end featuredim(d::SimulationData) = size(d.features[1], 1) nk(d::SimulationData) = size(d.features[2], 2) diff --git a/src/simulators/openmm.jl b/src/simulators/openmm.jl index fa08669..23b69c8 100644 --- a/src/simulators/openmm.jl +++ b/src/simulators/openmm.jl @@ -99,20 +99,6 @@ function OpenMMSimulation(; return OpenMMSimulation(pysim::PyObject, pdb, ligand, forcefields, temp, friction, step, steps, features, nthreads, mmthreads, momenta) end -function Base.show(io::IO, mime::MIME"text/plain", sim::OpenMMSimulation)# - println( - io, """ - OpenMMSimulation(; - pdb="$(sim.pdb)", - ligand="$(sim.ligand)", - forcefields=$(sim.forcefields), - temp=$(sim.temp), - friction=$(sim.friction), - step=$(sim.step), - steps=$(sim.steps), - features=$(sim.features))""" - ) -end function featurizer(sim::OpenMMSimulation) if sim.features isa (Vector{Int}) @@ -162,18 +148,18 @@ function propagate(s::OpenMMSimulation, x0::AbstractMatrix{T}, ny; stepsize=s.st dim, nx = size(x0) xs = repeat(x0, outer=[1, ny]) xs = permutedims(reinterpret(Tuple{T,T,T}, xs)) - ys = @pycall py"threadedrun"(xs, s.pysim, stepsize, steps, nthreads, mmthreads, momenta)::PyArray + ys = @pycall py"threadedrun"(xs, s.pysim, stepsize, steps, nthreads, mmthreads, momenta)::Vector{Float32} ys = reshape(ys, dim, nx, ny) ys = permutedims(ys, (1, 3, 2)) checkoverflow(ys) # control the simulated data for NaNs and too large entries and throws an error - return convert(AbstractArray{Float32}, ys) + return ys#convert(Array{Float32,3}, ys) end -propagate(s::OpenMMSimulation, x0::CuArray, ny; nthreads=Threads.nthreads()) = cu(propagate(s, collect(x0), ny; nthreads)) +#propagate(s::OpenMMSimulation, x0::CuArray, ny; nthreads=Threads.nthreads()) = cu(propagate(s, collect(x0), ny; nthreads)) struct OpenMMOverflow{T} <: Exception where {T} result::T - select::Vector{Bool} + select::Vector{Bool} # flags which results are valid end function checkoverflow(ys, overflow=100) @@ -332,6 +318,20 @@ Base.convert(::Type{OpenMMSimulation}, a::OpenMMSimulationSerialized) = nthreads=a.nthreads, mmthreads=a.mmthreads) +function Base.show(io::IO, mime::MIME"text/plain", sim::OpenMMSimulation)# + println( + io, """ + OpenMMSimulation(; + pdb="$(sim.pdb)", + ligand="$(sim.ligand)", + forcefields=$(sim.forcefields), + temp=$(sim.temp), + friction=$(sim.friction), + step=$(sim.step), + steps=$(sim.steps), + features=$(sim.features))""" + ) +end end #module