Skip to content

Commit

Permalink
cleaned up vgv
Browse files Browse the repository at this point in the history
closes #10
  • Loading branch information
axsk committed Mar 1, 2024
1 parent 7c30599 commit 7af3645
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 75 deletions.
1 change: 1 addition & 0 deletions src/ISOKANN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using SpecialFunctions: erf
using Plots: plot, plot!, scatter, scatter!

using MLUtils: numobs, getobs, shuffleobs, unsqueeze
import MLUtils: numobs

import ChainRulesCore
import Flux
Expand Down
14 changes: 14 additions & 0 deletions src/iso1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ Base.@kwdef mutable struct IsoRun{T} # takes 10 min
minibatch::Int = 128
end

# TODO: should make these defaults for sim==nothing
"""
    IsoRunFixedData(; data, kwargs...)

Create an `ISOKANN.IsoRun` that trains on the given `data` only, i.e. with
`sim=nothing`. The model defaults to `ISOKANN.pairnet(data)`.

Any of the presets below can be overridden through `kwargs`, since those are
splatted last.
"""
function IsoRunFixedData(; data, kwargs...)
    presets = (;
        data=data,
        model=ISOKANN.pairnet(data),
        nd=1,
        minibatch=0,
        nx=0,    # no chi subsampling
        nres=0,  # no resampling
        np=1,    # power iterations
        nl=1,    # weight updates
        nk=0,
        ny=0,
        sim=nothing,
    )
    return ISOKANN.IsoRun(; presets..., kwargs...)
end

# Extract the optimiser parameters of an `IsoRun` by inspecting the rule stored for
# the bias of the second layer in `iso.opt`.
optparms(iso::IsoRun) = optparms(iso.opt.layers[2].bias.rule)
# An optimiser chain yields the parameters of each of its member rules.
optparms(o::Optimisers.OptimiserChain) = map(optparms, o.opts)
# Weight decay is reported as a named tuple holding its `gamma` field.
optparms(o::Optimisers.WeightDecay) = (; WeightDecay=o.gamma)
Expand Down
17 changes: 16 additions & 1 deletion src/pairdists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,29 @@ function flatpairdists(x)
return reshape(p, c * c, s...)
end

# TODO: note that we return the squared distances here. We should take the sqrt, but first check whether aladip and so on still work fine.
"""
    simplepairdists(x)

Return the *squared* pairwise distances between the columns of every batch slice
of the three-dimensional array `x`.

For a slice with columns `xᵢ` the entry `(i, j)` is
`-2 xᵢ⋅xⱼ + |xᵢ|² + |xⱼ|²`. With `x` of size `(d, n, b)` the result has size
`(n, n, b)`.
"""
function simplepairdists(x)
    # squared column norms, size (1, n, b); computed once and reused for both terms
    sqnorms = sum(abs2, x, dims=1)
    gram = batched_mul(batched_adjoint(x), x)
    return -2 .* gram .+ sqnorms .+ PermutedDimsArray(sqnorms, (2, 1, 3))
end

using Distances: pairwise, Euclidean
using LinearAlgebra: diagind, UpperTriangular
# using Distances
"""
    batchedpairdists(x)

Compute the Euclidean distances between all pairs of columns for every
`(1, 2)`-slice of `x` and keep only the entries selected by
`halfinds(size(x, 2))`.

The first two dimensions of `x` are replaced by a single dimension that
enumerates the kept pairs; all trailing dimensions are preserved.
"""
function batchedpairdists(x)
    inds = halfinds(size(x, 2))
    dists = mapslices(x, dims=(1, 2)) do frame
        # `dims=2` (distances between columns) is passed explicitly; leaving it
        # out relies on a deprecated implicit default of `pairwise`.
        pairwise(Euclidean(), frame, dims=2)[inds]
    end
    # each slice was mapped to a vector, so dimension 2 is a singleton now
    return dropdims(dists, dims=2)
end

"""
    halfinds(n)

Return the `CartesianIndex`es `(i, j)` with `i < j` of an `n × n` matrix, i.e. its
strict upper triangle, in column-major order.

The indices are generated directly, without allocating an `n × n` matrix.
"""
function halfinds(n)
    return CartesianIndex{2}[CartesianIndex(i, j) for j in 1:n for i in 1:(j - 1)]
end


### custom implementation of multithreaded pairwise distances
function batchedpairdists(x::AbstractArray)
function batchedpairdists_threaded(x::AbstractArray)
ChainRulesCore.@ignore_derivatives begin
d, n, cols = size(x)
out = similar(x, n, n, cols)
Expand Down
117 changes: 43 additions & 74 deletions src/vgv/vgv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using JLD2
using Flux: Flux, cpu, gpu, Dense, Chain
using StatsBase: mean, std
using Plots
import Distances: pairwise, Euclidean

using LinearAlgebra: UpperTriangular
using Optimisers: Optimisers
using Flux: Dense
Expand All @@ -21,68 +21,46 @@ struct VGVData2 <: VGVData
coords
end

# Default locations of the VGVAPG data sets.
# NOTE(review): these look like machine-specific absolute paths - confirm they
# are valid on the target system or pass `dir` explicitly.
VGV_DATA_DIR = "/scratch/htc/ldonati/VGVAPG/implicit"
VGV_DATA_5000 = "/scratch/htc/ldonati/VGVAPG/implicit5000"

# Convenience constructor for the data stored under `VGV_DATA_5000`.
VGV5000(; kw...) = VGVData(VGV_DATA_5000, nx=5000, nk=10, t=1, natoms=73; kw...)

"""
    VGVData(dir=VGV_DATA_DIR; nx=500, nk=100, t=1, natoms=73)

Load the data stored under `dir` via `vgv_readdata` and return a `VGVData2`
holding

- `dir`,
- `data = (dx, dy)`: the batched pairwise distances of the initial states `xs`
  and the final states `ys`, each divided by 10,
- `coords`: the initial states reshaped to one column per state.

NOTE(review): the division by 10 is presumably a unit conversion - confirm.
"""
function VGVData(dir=VGV_DATA_DIR; nx=500, nk=100, t=1, natoms=73)
    xs, ys = vgv_readdata(dir, nx, nk, t, natoms)
    coords = reshape(xs, :, size(xs, 3))
    dx = batchedpairdists(xs) ./ 10
    dy = batchedpairdists(ys) ./ 10
    data = (dx, dy)
    return VGVData2(dir, data, coords)
end

# Path of the PDB file `input/initial_states/x0_1.pdb` inside the data directory of `v`.
function pdbfile(v::VGVData)
    return joinpath(v.dir, "input/initial_states/x0_1.pdb")
end
# Return the coordinates stored in `v`, asserting that they form a matrix.
function getcoords(v::VGVData)
    return v.coords::AbstractMatrix
end

# Return, for every sample in the first data array of `v`, the feature that belongs
# to the index pair (1, 71) among the pair indices `halfinds(73)`.
function reactioncoord(v::VGVData)
    row = findfirst(==(CartesianIndex(1, 71)), halfinds(73))
    return v.data[1][row, :]
end

# TODO: we need a proper caching idea, probably Memoize or Albert
#=function alldata_cached(args...)
global VGVDATA
@isdefined(VGVDATA) || (VGVDATA = alldata(args...))
return VGVDATA
end
=#

"""
    vgv_readdata(dir, nx, nk, t, natoms)

Read `nx` initial states and, for each of them, `nk` final states from `dir`.

Initial state `i` is read from `input/initial_states/x0_(i-1).pdb` and its `k`-th
final state from `output/final_states/xf_(i-1)_r(k-1).dcd` (file names are
zero-based). `t` is forwarded as second argument to `readchemfile` for the
final states (presumably the frame to read - confirm against `readchemfile`).

Returns `(xs, ys)` with `xs` of size `(3, natoms, nx)` and `ys` of size
`(3, natoms, nk, nx)`.
"""
function vgv_readdata(dir, nx, nk, t, natoms)
    xs = zeros(3, natoms, nx)
    ys = zeros(3, natoms, nk, nx)
    @showprogress for i in 1:nx
        xs[:, :, i] .= readchemfile("$dir/input/initial_states/x0_$(i-1).pdb", 1)
        for k in 1:nk
            ys[:, :, k, i] .= readchemfile("$dir/output/final_states/xf_$(i-1)_r$(k-1).dcd", t)
        end
    end

    return xs, ys
end

# TODO: implement batched dists instead
# cf. pairwisedists
# Pairwise Euclidean distances between the columns of every slice of `xs` (slices
# along dimension 3) and `ys` (slices along dimensions 3 and 4). Only the entries
# selected by `halfinds` are kept, and both results are divided by 10.
function pairwisedata(xs, ys)
    coldists(co) = pairwise(Euclidean(), co, dims=2)

    dx = stack(coldists, eachslice(xs, dims=3))
    dy = stack(coldists, eachslice(ys, dims=(3, 4)))

    keep = halfinds(size(xs, 2))
    return dx[keep, :] ./ 10, dy[keep, :, :] ./ 10
end

# Indices `(i, j)` with `i < j` of an `n × n` matrix (its strict upper triangle) in
# column-major order, generated directly instead of masking a dense matrix.
function halfinds(n)
    return CartesianIndex{2}[CartesianIndex(i, j) for j in 1:n for i in 1:(j - 1)]
end


### ISOKANN MODELS

# a copy of Luca's Python model
function vgv_luca(;v=VGVData(), kwargs...)
function vgv_luca(v::VGVData=VGVData(); kwargs...)
model = lucanet2(size(v.data, 1))
opt = AdamRegularized(5e-4, 1e-5)

Expand All @@ -98,19 +76,12 @@ function vgv_luca(;v=VGVData(), kwargs...)
return iso
end

# TODO: should make these defaults for sim==nothing
# Build an `ISOKANN.IsoRun` on fixed `data` with `sim=nothing`; the presets below can
# be overridden through `kwargs`, which are splatted last.
function IsoRunFixedData(; data, kwargs...)
    presets = (;
        data=data,
        model=ISOKANN.pairnet(data),
        nd=1,
        minibatch=0,
        nx=0,    # no chi subsampling
        nres=0,  # no resampling
        np=1,    # power iterations
        nl=1,    # weight updates
        sim=nothing,
    )
    return ISOKANN.IsoRun(; presets..., kwargs...)
end


# Variant of `vgv_luca` that uses `ISOKANN.pairnet(v.data)` as model and
# `ISOKANN.AdamRegularized(1e-3, 1e-3)` as optimiser, with `nd=100` and `nl=1`.
# Keyword arguments in `kw` are splatted last and therefore override these settings.
function vgv_alex(v::VGVData=VGVData(); kw...)
    return vgv_luca(v;
        model=ISOKANN.pairnet(v.data),
        opt=ISOKANN.AdamRegularized(1e-3, 1e-3),
        nd=100,
        nl=1,
        kw...)
end

lucanet1(dim; activation=Flux.sigmoid) = Chain(Dense(dim => 2048, activation),
Dense(2048 => 1024, activation),
Expand All @@ -122,15 +93,6 @@ lucanet2(dim; activation=Flux.sigmoid) = Chain(Dense(dim => 204, activation),
Dense(102 => 51, activation),
Dense(51 => 1, identity))

# Keyword-based variant of `vgv_luca` using `ISOKANN.pairnet(v.data)` as model, with
# `nd=1000` and `nl=1`. Keyword arguments in `kw` are splatted last and override these.
# The default data set is built with `VGVData()`.
vgv_alex(; v=VGVData(), kw...) = vgv_luca(;
    v,
    model=ISOKANN.pairnet(v.data),
    # NOTE(review): the first argument `0e-3` is literally zero - if it is the
    # learning rate, no training progress is possible; confirm the intended value.
    opt=ISOKANN.AdamRegularized(0e-3, 1e-3),
    nd=1000,
    nl=1,
    kw...)


### OUTPUT

function scatter_reactioncoord(iso, v::VGVData)
Expand All @@ -140,11 +102,8 @@ function scatter_reactioncoord(iso, v::VGVData)
end

# Evaluate the model of `iso` on every frame of the trajectory stored at
# `output/trajectory.dcd` below `v.dir` and plot the resulting chi values against the
# frame number. The frames are converted to batched pairwise distances divided by 10,
# matching how `dx` is built in the `VGVData` constructor.
function plot_longtraj(iso, v::VGVData)
    xs = ISOKANN.IsoMu.readchemfile("$(v.dir)/output/trajectory.dcd")
    dx = batchedpairdists(xs) ./ 10
    # the model is evaluated on the GPU; the result is moved back and flattened
    vals = iso.model(dx |> gpu) |> cpu |> vec
    return plot(vals, xlabel="frame #", ylabel="chi", title="long traj")
end
Expand All @@ -161,16 +120,26 @@ end

### EXAMPLE

"""
    vgv_examplerun(v=VGVData(), outdir="out/vgvexample")

Train an ISOKANN model built by `vgv_alex(v)` and write all results to `outdir`,
which is created if necessary:

- `training.png` and `longtraj.png`: the training and long-trajectory plots,
- `chisorted.pdb`: output of `save_sorted_path`,
- `reactionpath.pdb`: output of `save_reactive_path`,
- `parameters.txt`: the plain-text representation of the trained `iso` object.

Returns the trained `iso` object.
"""
function vgv_examplerun(v=VGVData(), outdir="out/vgvexample")
    mkpath(outdir)
    iso = vgv_alex(v)
    run!(iso)

    # `savefig` without a plot argument saves the current (just displayed) plot
    plot_training(iso) |> display
    savefig("$outdir/training.png")

    plot_longtraj(iso, v) |> display
    savefig("$outdir/longtraj.png")

    save_sorted_path(iso, v, "$outdir/chisorted.pdb")
    save_reactive_path(Iso2(iso), v.coords,
        sigma=1,
        out="$outdir/reactionpath.pdb",
        source=pdbfile(v))

    open("$outdir/parameters.txt", "w") do io
        show(io, MIME"text/plain"(), iso)
    end

    return iso
end

0 comments on commit 7af3645

Please sign in to comment.