Skip to content

Commit

Permalink
clean up use of defaultmodel
Browse files Browse the repository at this point in the history
fixed most of the tests
  • Loading branch information
axsk committed Aug 21, 2024
1 parent 5ea4ae9 commit 194195f
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 36 deletions.
6 changes: 4 additions & 2 deletions src/ISOKANN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ import PCCAPlus
import MLUtils: numobs
import Flux: cpu, gpu

export pairnet#, pairnetn
export pairnet
#export PDB_ACEMD, PDB_1UAO, PDB_diala_water
#export MollyLangevin, propagate, solve#, MollySDE

export propagate

export run!, runadaptive!
export AdamRegularized, pairnet#, Adam
export AdamRegularized, NesterovRegularized
export plot_training, scatter_ramachandran
export reactive_path, save_reactive_path
export cpu, gpu
Expand All @@ -67,6 +67,8 @@ export chis
export SimulationData
export getxs, getys
export exit_rates
export atom_indices
export load_trajectory, save_trajectory

export reactionpath_minimum, reactionpath_ode

Expand Down
15 changes: 6 additions & 9 deletions src/iso2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ end


"""
Iso2(data; opt=AdamRegularized(), model=pairnet(data), gpu=false, kwargs...)
Iso2(data; opt=AdamRegularized(), model=defaultmodel(data), gpu=false, kwargs...)
"""
function Iso2(data; opt=AdamRegularized(), model=pairnet(data), gpu=false, kwargs...)
function Iso2(data; opt=AdamRegularized(), model=defaultmodel(data), gpu=false, kwargs...)
opt = Flux.setup(opt, model)
transform = outputdim(model) == 1 ? TransformShiftscale() : TransformISA()

Expand All @@ -34,13 +34,10 @@ and constructs the Iso2 object. See also Iso2(data; kwargs...)
- `sim::IsoSimulation`: The `IsoSimulation` object.
- `nx::Int`: The number of starting points.
- `nk::Int`: The number of koopman samples.
- `nd::Int`: Dimension of the χ function.
- `nout::Int`: Dimension of the χ function.
"""
function Iso2(sim::IsoSimulation; nx=100, nk=10, nd=1, kwargs...)
data = SimulationData(sim, nx, nk)
model = pairnet(data; nout=nd) # maybe defaultmodel(data) makes sense here?
return Iso2(data; model, kwargs...)
end
Iso2(sim::IsoSimulation; nx=100, nk=2, kwargs...) = Iso2(SimulationData(sim, nx, nk); kwargs...)


#Iso2(iso::IsoRun) = Iso2(iso.model, iso.opt, iso.data, TransformShiftscale(), iso.losses, iso.loggers, iso.minibatch)

Expand Down Expand Up @@ -199,7 +196,7 @@ function exit_rates(x, kx, tau)
x = vec(x)
kx = vec(kx)
P = [x o .- x] \ [kx o .- kx]
return -1 / tau .* Base.log.(diag(P))
return -1 / tau .* [p > 0 ? Base.log(p) : NaN for p in diag(P)]
end

koopman(iso::Iso2) = koopman(iso.model, getys(iso.data))
Expand Down
8 changes: 3 additions & 5 deletions src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,15 @@ function featurizer(sim)
end
=#

featuredim((xs, ys)::Tuple) = size(xs, 1)
pairnet(data; kwargs...) = pairnet(featuredim(data); kwargs...)

function pairnet((xs, ys)::Tuple; kwargs...)
pairnet(size(xs, 1); kwargs...)
end

""" Fully connected neural network with `layers` layers from `n` to `nout` dimensions.
`features` allows to pass a featurizer as preprocessor,
`activation` determines the activation function for each but the last layer
`lastactivation` can be used to modify the last layers activation function """
function pairnet(n::Int=22; layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1, layernorm=true)
function pairnet(; n::Int, layers=3, features=identity, activation=Flux.sigmoid, lastactivation=identity, nout=1, layernorm=true)
float32(x) = Float32.(x)
nn = Flux.Chain(
#float32,
Expand All @@ -76,7 +74,7 @@ function growmodel(m, n)
end

# Used by AbstractLangevin
function smallnet(nin, nout, activation=nl = Flux.sigmoid, lastactivation=identity)
function smallnet(nin, nout=1, activation=nl = Flux.sigmoid, lastactivation=identity)
model = Flux.Chain(
Flux.Dense(nin, 5, activation),
Flux.Dense(5, 10, activation),
Expand Down
1 change: 1 addition & 0 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ function features(d::SimulationData, x)
return d.featurizer(x)
end

defaultmodel(d::SimulationData; kwargs...) = defaultmodel(d.sim; n=featuredim(d), kwargs...)
featuredim(d::SimulationData) = size(d.features[1], 1)
nk(d::SimulationData) = size(d.features[2], 2)

Expand Down
2 changes: 1 addition & 1 deletion src/simulators/langevin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ abstract type AbstractLangevin <: IsoSimulation end

featurizer(::AbstractLangevin) = identity
integrator(::AbstractLangevin) = StochasticDiffEq.EM()
defaultmodel(l::AbstractLangevin; nout, kwargs...) = smallnet(dim(l), nout)
defaultmodel(l::AbstractLangevin; n=dim(l), kwargs...) = smallnet(n; kwargs...)

function SDEProblem(l::AbstractLangevin, x0=randx0(l), T=lagtime(l); dt=dt(l), alg=integrator(l), kwargs...)
drift(x,p,t) = force(l, x)
Expand Down
2 changes: 1 addition & 1 deletion src/simulators/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end

lagtime(sim::OpenMMSimulation) = sim.step * sim.steps
dim(sim::OpenMMSimulation) = return length(getcoords(sim))
defaultmodel(sim::OpenMMSimulation; kwargs...) = ISOKANN.pairnet(sim; kwargs...)
defaultmodel(sim::OpenMMSimulation; kwargs...) = ISOKANN.pairnet(; kwargs...)
pdb(s::OpenMMSimulation) = s.pdb

"""
Expand Down
32 changes: 14 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,30 @@ using ISOKANN
using Test
using CUDA

backends = [cpu]
if CUDA.functional()
CUDA.allowscalar(false)
push!(backends, gpu)
else
@info "No functional GPU found. Marking GPU test as broken."
@test_broken false
@info "No functional GPU found. Skipping GPU tests"
end

@time begin


@testset "ISOKANN.jl" verbose = true begin
@time @testset "ISOKANN.jl" verbose = true begin

simulations = zip([Doublewell(), Triplewell(), MuellerBrown(), ISOKANN.OpenMM.OpenMMSimulation(), ISOKANN.OpenMM.OpenMMSimulation(features=0.3)], ["Doublewell", "Triplewell", "MuellerBrown", "OpenMM", "OpenMM localdists"])

for backend in [cpu, gpu]
for backend in backends

@testset "Running basic system tests" begin
@testset "Running basic system tests on $backend" begin
for (sim, name) in simulations
@testset "Testing ISOKANN with $name ($backend)" begin
@testset "Testing ISOKANN with $name" begin
i = Iso2(sim) |> backend
@test true
run!(i)
@test true
runadaptive!(i, generations=2, nx=1, iter=1)
@test true
ISOKANN.addextrapolates!(i, 1, stepsize=0.01, steps=1)
#ISOKANN.addextrapolates!(i, 1, stepsize=0.01, steps=1)
@test true
end
end
Expand All @@ -37,10 +35,10 @@ end
@testset "Iso2 Transforms ($backend)" begin
sim = MuellerBrown()
for (d, t) in zip([1, 2, 2], [ISOKANN.TransformShiftscale(), ISOKANN.TransformPseudoInv(), ISOKANN.TransformISA()])
@test begin
run!(Iso2(sim, model=pairnet(2, nout=d), transform=t) |> backend)
true
end
@test begin
run!(Iso2(sim, model=pairnet(n=2, nout=d), transform=t) |> backend)
true
end
end
end
end
Expand All @@ -49,13 +47,11 @@ end
@testset "Iso2 and IsoSimulation operations" begin
iso = Iso2(OpenMMSimulation(), nx=10)
iso.data = iso.data[6:10] # data slicing
path = Base.Filesystem.tempname() * ".jld2"
path = Base.Filesystem.tempname() * ".jld2"
ISOKANN.save(path, iso)
isol = ISOKANN.load(path, iso)
isol = ISOKANN.load(path)
@assert iso.data.coords == isol.data.coords
runadaptive!(isol, generations=1, nx=1, iter=1)
@test true
end
end

end

0 comments on commit 194195f

Please sign in to comment.