Skip to content

Commit

Permalink
add new kde sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed Nov 20, 2024
1 parent e203289 commit e9e1c36
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Sikorski <[email protected]> and contributors"]
version = "1.0.0"

[deps]
AverageShiftedHistograms = "77b51b56-6f8f-5c3a-9cb4-d71f9594ea6e"
Bonito = "824d6782-a2ef-11e9-3a09-e5662e0c26f8"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -52,6 +53,7 @@ Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c"
MollyExt = "Molly"

[compat]
AverageShiftedHistograms = "0.8.9"
Bonito = "3.1.2"
CUDA = "5.2.0"
ChainRulesCore = "1.23.0"
Expand Down
4 changes: 2 additions & 2 deletions src/iso.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
data::D
transform
losses = Float64[]
loggers = [autoplot(1)]
loggers = Any[autoplot(1)]
minibatch = 100
end

Expand All @@ -20,7 +20,7 @@ function Iso(data;
gpu=CUDA.has_cuda(),
autoplot=0,
validation=nothing,
loggers::Vector{Any}=[],
loggers=[],
kwargs...)

opt = Flux.setup(opt, model)
Expand Down
13 changes: 8 additions & 5 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,20 @@ function chistratcoords(d::SimulationData, model, n; keepedges=false)
end


function resample_kde(data::SimulationData, model, n; bandwidth=0.02, unique=true)
function resample_kde(data::SimulationData, model, n; bandwidth=0.02, unique=false)
n == 0 && return data

sampled = Set(eachcol(data.coords[1]))
selinds = unique ? [i for (i, c) in enumerate(eachcol(values(data.coords[2]) |> flattenlast)) if !(c in sampled)] : (:)

selinds = if unique
sampled = Set(eachcol(data.coords[1]))
[i for (i, c) in enumerate(eachcol(values(data.coords[2]) |> flattenlast)) if !(c in sampled)]
else
(:)
end

chix = data.features[1] |> model |> vec |> cpu
chiy = data.features[2] |> flattenlast |> x -> getindex(x, :, selinds) |> model |> vec |> cpu

iy = resample_kde(chix, chiy, n; bandwidth)
iy = resample_kde_ash(chix, chiy, n)

ys = values(data.coords[2]) |> flattenlast |> x -> getindex(x, :, selinds)
newdata = addcoords(data, ys[:, iy])
Expand Down
24 changes: 24 additions & 0 deletions src/subsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ end
"""
    to_pdf(f::Function)
    to_pdf(d::Distributions.Distribution)

Return a callable representing a probability density.

A `Function` is handed back unchanged. A `Distributions.Distribution` is wrapped in a
closure that forwards its argument to `Distributions.pdf(d, x)`.
"""
function to_pdf(f::Function)
    return f
end

function to_pdf(d::Distributions.Distribution)
    # Capture `d` so that callers can treat distributions and plain functions alike.
    density(x) = Distributions.pdf(d, x)
    return density
end

import AverageShiftedHistograms

function kde_needles(xs, n=10; bandwidth, target=Distributions.Uniform())
xs = copy(xs)
needles = similar(xs, 0)
Expand All @@ -115,3 +117,25 @@ function kde_needles(xs, n=10; bandwidth, target=Distributions.Uniform())
end
return needles
end

"""
    resample_kde_ash(xs, ys, n=10; m=50, target=Distributions.Uniform())

Greedily pick `n` distinct indices into `ys` such that adding the selected values to the
density estimate of `xs` moves it towards `target`.

An average shifted histogram of `xs` is built on the fixed grid `0:0.001:1`. In each of
the `n` rounds the grid position where `target` exceeds the current density estimate the
most is located, the not-yet-selected entry of `ys` closest to that position is chosen,
and its value is added to the estimate before the next round.

# Arguments
- `xs`: values used for the initial density estimate.
- `ys`: candidate values to select from.
- `n`: number of indices to select.

# Keywords
- `m`: forwarded to `AverageShiftedHistograms.ash` (presumably its smoothing
  parameter; confirm against the package documentation).
- `target`: a function or `Distributions.Distribution`; converted with `to_pdf` and
  evaluated on the whole grid at once.

# Returns
A `Vector{Int}` of length `n` with the selected indices into `ys`, in selection order.

# Throws
- `ArgumentError` if fewer than `n` selectable candidates exist in `ys` (e.g. `ys` is
  shorter than `n`, or the remaining entries are `NaN`).

NOTE(review): the grid is hard-coded to `[0, 1]`, so `xs` and `ys` are presumably expected
to lie in that interval; verify against the callers.
"""
function resample_kde_ash(xs, ys, n=10; m=50, target=Distributions.Uniform())
    iys = zeros(Int, n)
    rng = 0:0.001:1
    kde = AverageShiftedHistograms.ash(xs; rng, m)
    # Evaluate the target density once; it does not change between rounds.
    targetpdf = to_pdf(target)(rng)
    # Constant-time membership test for indices that were already selected.
    used = Set{Int}()
    for i in 1:n
        # Position of maximal difference between the target pdf and the current estimate.
        chi = rng[argmax(targetpdf - kde.density)]
        @debug "resample_kde_ash: position of maximal difference to target" i chi

        # Nearest unused candidate; the first one wins on ties.
        best = nothing
        mindist = Inf
        for j in eachindex(ys)
            j in used && continue
            dist = abs(ys[j] - chi)
            if dist < mindist
                mindist = dist
                best = j
            end
        end
        if isnothing(best)
            throw(ArgumentError(
                "cannot select $n distinct samples: no selectable candidate left in " *
                "`ys` (length $(length(ys))) after $(i - 1) selections",
            ))
        end

        # Update the density estimate with the chosen value so that the next round
        # targets the remaining deficit.
        AverageShiftedHistograms.ash!(kde, ys[best])
        push!(used, best)
        iys[i] = best
    end
    return iys
end

0 comments on commit e9e1c36

Please sign in to comment.