From 7af36450cfe5cf03ec9444422df8ad3a00134f95 Mon Sep 17 00:00:00 2001 From: Sikorski Date: Fri, 1 Mar 2024 14:04:49 +0100 Subject: [PATCH] cleaned up vgv closes #10 --- src/ISOKANN.jl | 1 + src/iso1.jl | 14 ++++++ src/pairdists.jl | 17 ++++++- src/vgv/vgv.jl | 117 +++++++++++++++++------------------------------ 4 files changed, 74 insertions(+), 75 deletions(-) diff --git a/src/ISOKANN.jl b/src/ISOKANN.jl index 5ca448b..186fae7 100644 --- a/src/ISOKANN.jl +++ b/src/ISOKANN.jl @@ -21,6 +21,7 @@ using SpecialFunctions: erf using Plots: plot, plot!, scatter, scatter! using MLUtils: numobs, getobs, shuffleobs, unsqueeze +import MLUtils: numobs import ChainRulesCore import Flux diff --git a/src/iso1.jl b/src/iso1.jl index ddb241f..12c35ac 100644 --- a/src/iso1.jl +++ b/src/iso1.jl @@ -61,6 +61,20 @@ Base.@kwdef mutable struct IsoRun{T} # takes 10 min minibatch::Int = 128 end +# TODO: should make these defaults for sim==nothing +IsoRunFixedData(; data, kwargs...) = ISOKANN.IsoRun(; + data=data, + model=ISOKANN.pairnet(data), + nd=1, + minibatch=0, + nx=0, # no chi subsampling, + nres=0, # no resampling, + np=1, # power iterations, + nl=1, # weight updates, + nk=0, + ny=0, + sim=nothing, kwargs...) + optparms(iso::IsoRun) = optparms(iso.opt.layers[2].bias.rule) optparms(o::Optimisers.OptimiserChain) = map(optparms, o.opts) optparms(o::Optimisers.WeightDecay) = (; WeightDecay=o.gamma) diff --git a/src/pairdists.jl b/src/pairdists.jl index a752f6b..62bccfa 100644 --- a/src/pairdists.jl +++ b/src/pairdists.jl @@ -14,14 +14,29 @@ function flatpairdists(x) return reshape(p, c * c, s...) end +# TODO: note we return the squred distances here. we should take the sqrt but check whether aladip and so on still work fine function simplepairdists(x) p = -2 .* batched_mul(batched_adjoint(x), x) .+ sum(abs2, x, dims=1) .+ PermutedDimsArray(sum(abs2, x, dims=1), (2, 1, 3)) return p end +using Distances: pairwise, Euclidean +using LinearAlgebra: diagind, UpperTriangular +# using Distances +function batchedpairdists(x) + inds = halfinds(size(x, 2)) + dropdims(mapslices(x -> pairwise(Euclidean(), x)[inds], x, dims=(1, 2)), dims=2) +end + +function halfinds(n) + a = UpperTriangular(ones(n, n)) + a[diagind(a)] .= 0 + findall(a .> 0) +end + ### custom implementation of multithreaded pairwise distances -function batchedpairdists(x::AbstractArray) +function batchedpairdists_threaded(x::AbstractArray) ChainRulesCore.@ignore_derivatives begin d, n, cols = size(x) out = similar(x, n, n, cols) diff --git a/src/vgv/vgv.jl b/src/vgv/vgv.jl index 4999b55..722f7ba 100644 --- a/src/vgv/vgv.jl +++ b/src/vgv/vgv.jl @@ -5,7 +5,7 @@ using JLD2 using Flux: Flux, cpu, gpu, Dense, Chain using StatsBase: mean, std using Plots -import Distances: pairwise, Euclidean + using LinearAlgebra: UpperTriangular using Optimisers: Optimisers using Flux: Dense @@ -21,68 +21,46 @@ struct VGVData2 <: VGVData coords end -function VGVData(dir= "data/luca/VGVAPG/implicit"; lag=1, nx=500, nk=100, nt=10) - xs, ys = alldata(dir, nx, nk, nt) +VGV_DATA_DIR = "/scratch/htc/ldonati/VGVAPG/implicit" +VGV_DATA_5000 = "/scratch/htc/ldonati/VGVAPG/implicit5000" + +VGV5000(; kw...) = VGVData(VGV_DATA_5000, nx=5000, nk=10, t=1, natoms=73; kw...) + +function VGVData(dir=VGV_DATA_DIR; nx=500, nk=100, t=1, natoms=73) + xs, ys = vgv_readdata(dir, nx, nk, t, natoms) coords = reshape(xs, :, size(xs,3)) - data = pairwisedata(xs, ys[:, :, :, :, lag]) + dx = batchedpairdists(xs) ./ 10 + dy = batchedpairdists(ys) ./ 10 + data = (dx, dy) VGVData2(dir, data, coords) end pdbfile(v::VGVData) = joinpath(v.dir, "input/initial_states/x0_1.pdb") getcoords(v::VGVData) = v.coords :: AbstractMatrix + function reactioncoord(v::VGVData) i = findfirst(CartesianIndex(1,71).==halfinds(73)) v.data[1][i,:] end # TODO: we need a proper caching idea, probably Memoize or Albert -#=function alldata_cached(args...) - global VGVDATA - @isdefined(VGVDATA) || (VGVDATA = alldata(args...)) - return VGVDATA -end -=# - -function alldata(dir, nx, nk, nt) - xs = stack((readchemfile("$dir/input/initial_states/x0_$(i-1).pdb", 1)[:, :, 1] for i in 1:nx)) - dim, atoms, nx = size(xs) - ys = similar(xs, dim, atoms, nk, nx, nt) - - for i in 1:nx, k in 1:nk - ys[:, :, k, i, :] .= readchemfile("$dir/output/final_states/xf_$(i-1)_r$(k-1).dcd") +function vgv_readdata(dir, nx, nk, t, natoms) + xs = zeros(3, natoms, nx) + ys = zeros(3, natoms, nk, nx) + @showprogress for i in 1:nx + xs[:, :, i] .= readchemfile("$dir/input/initial_states/x0_$(i-1).pdb", 1) + for k in 1:nk + ys[:, :, k, i] .= readchemfile("$dir/output/final_states/xf_$(i-1)_r$(k-1).dcd", t) + end end - xs, ys end -# TODO: implment batched dists instead -# cf pairwisedists -function pairwisedata(xs, ys) - dx = stack(eachslice(xs, dims=3)) do co - pairwise(Euclidean(), co, dims=2) - end - - dy = stack(eachslice(ys, dims=(3, 4))) do co - pairwise(Euclidean(), co, dims=2) - end - - inds = halfinds(size(xs, 2)) - - return dx[inds, :] ./ 10, dy[inds, :, :] ./ 10 -end - -function halfinds(n) - a = UpperTriangular(ones(n, n)) - a[diagind(a)] .= 0 - findall(a .> 0) -end - - ### ISOKANN MODELS # a copy of lucas python model -function vgv_luca(;v=VGVData(), kwargs...) +function vgv_luca(v::VGVData=VGVData(); kwargs...) model = lucanet2(size(v.data, 1)) opt = AdamRegularized(5e-4, 1e-5) @@ -98,19 +76,12 @@ function vgv_luca(;v=VGVData(), kwargs...) return iso end -# TODO: should make these defaults for sim==nothing -IsoRunFixedData(; data, kwargs...) = ISOKANN.IsoRun(; - data=data, - model=ISOKANN.pairnet(data), - nd=1, - minibatch=0, - nx=0, # no chi subsampling, - nres=0, # no resampling, - np=1, # power iterations, - nl=1, # weight updates, - sim=nothing, kwargs...) - - +vgv_alex(v::VGVData=VGVData(); kw...) = vgv_luca(v; + model=ISOKANN.pairnet(v.data), + opt=ISOKANN.AdamRegularized(1e-3, 1e-3), + nd=100, + nl=1, + kw...) lucanet1(dim; activation=Flux.sigmoid) = Chain(Dense(dim => 2048, activation), Dense(2048 => 1024, activation), @@ -122,15 +93,6 @@ lucanet2(dim; activation=Flux.sigmoid) = Chain(Dense(dim => 204, activation), Dense(102 => 51, activation), Dense(51 => 1, identity)) -vgv_alex(; v=VGVDATA(), kw...) = vgv_luca(; - v, - model=ISOKANN.pairnet(v.data), - opt=ISOKANN.AdamRegularized(0e-3, 1e-3), - nd=1000, - nl=1, - kw...) - - ### OUTPUT function scatter_reactioncoord(iso, v::VGVData) @@ -140,11 +102,8 @@ function scatter_reactioncoord(iso, v::VGVData) end function plot_longtraj(iso, v::VGVData) - xs = ISOKANN.IsoMu.readchemfile("$(v.dir)/implicit/output/trajectory.dcd") - inds = halfinds(size(xs, 2)) - dx = stack(eachslice(xs, dims=3)) do co - pairwise(Euclidean(), co, dims=2) - end[inds, :] ./ 10 + xs = ISOKANN.IsoMu.readchemfile("$(v.dir)/output/trajectory.dcd") + dx = batchedpairdists(xs) ./ 10 vals = iso.model(dx |> gpu) |> cpu |> vec plot(vals, xlabel="frame #", ylabel="chi", title="long traj") end @@ -161,16 +120,26 @@ end ### EXAMPLE -function vgv_examplerun(v=VGVData()) - iso = vgv_alex(;v) +function vgv_examplerun(v=VGVData(), outdir="out/vgvexample") + mkpath(outdir) + iso = vgv_alex(v) run!(iso) + plot_training(iso) |> display + savefig("$outdir/training.png") + + plot_longtraj(iso, v) |> display + savefig("$outdir/longtraj.png") - save_sorted_path(iso, v) + save_sorted_path(iso, v, "$outdir/chisorted.pdb") save_reactive_path(Iso2(iso), v.coords, sigma=1, - out="out/vgv/reactionpath.pdb", + out="$outdir/reactionpath.pdb", source=pdbfile(v)) + open("$outdir/parameters.txt", "w") do io + show(io, MIME"text/plain"(), iso) + end + return iso end \ No newline at end of file