mainly documentation
axsk committed Jan 18, 2024
1 parent 49a8d3a commit bc7b993
Showing 6 changed files with 36 additions and 9 deletions.
2 changes: 0 additions & 2 deletions docs/src/index.md
@@ -20,13 +20,11 @@ propagate
```@autodocs
Modules = [ISOKANN]
Private = false
- Order = [:function]
```

## Internal API

```@autodocs
Modules = [ISOKANN]
Public = false
- Order = [:function]
```
4 changes: 2 additions & 2 deletions src/data.jl
@@ -29,7 +29,7 @@ end
Returns `n` indices of `xs` such that `model(xs[inds])` is approximately uniformly distributed.
"""
function subsample_inds(model, xs, n)
- reduce(vcat, eachrow(model(xs))) do row
+ mapreduce(vcat, eachrow(model(xs))) do row
subsample_uniformgrid(shiftscale(row), n)
end::Vector{Int}
end
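The `mapreduce(vcat, ...)` form introduced above concatenates the per-row index vectors into a single `Vector{Int}`. A toy illustration of the pattern (with `sortperm` as a hypothetical stand-in for `subsample_uniformgrid`, which is not shown here):

```julia
M = [0.1 0.9 0.5;             # 2 model outputs × 3 samples
     0.2 0.8 0.4]
inds = mapreduce(vcat, eachrow(M)) do row
    sortperm(row)[1:2]        # stand-in for subsample_uniformgrid(shiftscale(row), n)
end
# concatenates one length-2 index vector per row into a single Vector{Int} of length 4
```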
@@ -38,7 +38,7 @@ end
subsample(model, data::Array, n) :: Matrix
subsample(model, data::Tuple, n) :: Tuple
- Subsample `n`` points of `data` uniformly in `model`.
+ Subsample `n` points of `data` uniformly in `model`.
If `model` returns multiple values per sample, subsample along each dimension.
"""
subsample(model, xs::AbstractArray{<:Any,2}, n) =
27 changes: 24 additions & 3 deletions src/isomolly.jl
@@ -2,7 +2,27 @@
"""
struct IsoRun{T}
- The `IsoRun` struct represents a configuration for running the Isomolly algorithm.
+ The `IsoRun` struct represents a configuration for running the ISOKANN algorithm with adaptive sampling.
The whole algorithm consists of three nested loops:
1. `nd` iterations of the data loop where `nx` points are subsampled (via stratified χ-subsampling) from the pool of all available data
2. `np` iterations of the power iteration where the training target is determined with the current model and subdata
3. `nl` iterations of the SGD updates to the neural network model to learn the current target
On initialization it samples `ny` starting positions with `nk` Koopman samples each.
Furthermore, if `nres` > 0, it adaptively samples `ny` new data points starting from χ-sampled positions every `nres` steps of the data loop.
The `sim` field takes any simulation object that implements the data sampling interface (mainly the `propagate` method, see data.jl),
usually a `MollyLangevin` simulation.
`model` and `opt` store the neural network model and the optimizer (defaulting to a `pairnet` and `AdamRegularized`).
`data` contains the training data and is by default constructed using the `bootstrap` method.
The vector `losses` keeps track of the training loss, and `loggers` allows passing in logging functions which are executed in the power iteration loop.
To start the actual training, call the `run!` method.
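The three nested loops described above can be sketched as follows (the helpers `koopman_target` and `train_step!` are hypothetical stand-ins for the actual internals of `run!`):

```julia
function run_sketch!(iso::IsoRun)
    for _ in 1:iso.nd                                 # 1. data loop
        sub = subsample(iso.model, iso.data, iso.nx)  # stratified χ-subsampling
        for _ in 1:iso.np                             # 2. power iteration
            target = koopman_target(iso.model, sub)   # hypothetical helper
            for _ in 1:iso.nl                         # 3. SGD on the current target
                train_step!(iso.model, sub, target, iso.opt)  # hypothetical helper
            end
        end
        # adaptive resampling every `nres` steps would extend `iso.data` here
    end
    return iso
end
```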
# Fields
- `nd::Int64`: Number of outer data subsampling steps.
@@ -21,7 +41,6 @@ The `IsoRun` struct represents a configuration for running the Isomolly algorithm
- `loggers::Vector`: Vector of loggers.
"""

Base.@kwdef mutable struct IsoRun{T} # takes 10 min
nd::Int64 = 1000 # number of outer datasubsampling steps
nx::Int64 = 100 # size of subdata set
@@ -92,6 +111,8 @@ function run!(iso::IsoRun; showprogress=true)
return iso
end

log(f::Function; kwargs...) = f(; kwargs...)

# note there is also plot_callback in isokann.jl
function autoplot(secs=10)
Flux.throttle(
@@ -116,7 +137,7 @@ end
""" empirical shift-scale operation """
shiftscale(ks) = (ks .- minimum(ks)) ./ (maximum(ks) - minimum(ks))
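For reference, this empirical shift-scale maps the minimum of its input to 0 and the maximum to 1, affinely:

```julia
shiftscale([2.0, 4.0, 6.0])  # == [0.0, 0.5, 1.0]
```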

""" batched supervised learning for a given batchsize """
""" DEPRECATED - batched supervised learning for a given batchsize """
function learnbatch!(model, xs::AbstractMatrix, target::AbstractVector, opt, batchsize)
ndata = length(target)
if ndata <= batchsize || batchsize == 0
4 changes: 4 additions & 0 deletions src/models.jl
@@ -21,6 +21,10 @@ function featureinds(sim::IsoSimulation)
end
end

""" convenience wrapper returning the provided model with the default AdamW optimiser """
model_with_opt(model, learnrate=1e-2, decay=1e-5) =
(; model, opt=Flux.setup(Flux.AdamW(learnrate, (0.9, 0.999), decay), model))
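A usage sketch of the new `model_with_opt` wrapper (the network architecture is chosen arbitrarily for illustration):

```julia
using Flux
m = Chain(Dense(2 => 8, relu), Dense(8 => 1))
nt = model_with_opt(m)        # defaults: learnrate = 1e-2, decay = 1e-5
nt.model === m                # the model is passed through unchanged
# nt.opt holds the Flux optimiser state for AdamW, ready for Flux.update!
```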

""" given an array of arbitrary shape, select the rows `inds` in the first dimension """
function selectrows(x, inds)
d, s... = size(x)
5 changes: 3 additions & 2 deletions src/molly.jl
@@ -156,7 +156,8 @@ Simulates the overdamped Langevin equation using the Euler-Maruyama method with
with σ = sqrt(2KT/(mγ))
dX = (-∇U(X)/(γm) + σu) dt + σ dW
- where u is the control function, such that u(x,t) = σ .* w(x,t)
+ where u is the control function, such that u(x,t) = σ .* w(x,t).
The accumulated Girsanov reweighting is stored in the field `g`.
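A minimal Euler-Maruyama step for the controlled SDE above, in scalar form for clarity (`grad_U` and `w` are illustrative placeholders, and the Girsanov increment shown is an assumption based on the standard log-likelihood-ratio formula, not taken from this commit):

```julia
function em_step(x, g, t, dt, grad_U, w; γ=1.0, m=1.0, kT=1.0)
    σ = sqrt(2kT / (m * γ))
    u = σ * w(x, t)                              # control u(x,t) = σ ⋅ w(x,t)
    dW = sqrt(dt) * randn()
    x′ = x + (-grad_U(x) / (γ * m) + σ * u) * dt + σ * dW
    g′ = g + w(x, t) * dW + w(x, t)^2 / 2 * dt   # assumed Girsanov increment
    return x′, g′
end
```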
# Arguments
- `dt::S`: the time step of the simulation.
@@ -172,7 +173,7 @@ struct OverdampedLangevinGirsanov{S, K, F, Fct}
friction::F
remove_CM_motion::Int
g::Float64 # the Girsanov integral
- w::Fct # control function in the form w = uσ
+ w::Fct
end

function OverdampedLangevinGirsanov(; dt, temperature, friction, w, remove_CM_motion=1, G=0.)
3 changes: 3 additions & 0 deletions src/simulation.jl
@@ -82,6 +82,9 @@ end
Burst simulation of the MollyLangevin system `ms`.
Propagates `ny` samples for each initial position provided in the columns of `x0`.
`propagate` is the main interface facilitating sampling of a system.
TODO: specify the actual interface required for a simulation to be runnable by ISOKANN.
# Arguments
- `ms::MollyLangevin`: The MollyLangevin solver object.
- `x0::AbstractMatrix`: The initial positions matrix.
