Skip to content

Commit

Permalink
Merge pull request #50 from xukai92/change-default-turing
Browse files Browse the repository at this point in the history
Make Turing.jl default sampler to NUTS
  • Loading branch information
ChrisRackauckas authored Jul 10, 2018
2 parents 8e9548f + 5c35b4c commit d3282b2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function turing_inference(prob::DEProblem,alg,t,data,priors = nothing;
num_samples=1000, epsilon = 0.02, tau = 4, kwargs...)
num_samples=1000, delta=0.65, kwargs...)

bif(vi, sampler, x=data) = begin
_lp = 0.0
Expand Down Expand Up @@ -42,5 +42,5 @@ function turing_inference(prob::DEProblem,alg,t,data,priors = nothing;

bif() = bif(Turing.VarInfo(), nothing)

chn = sample(bif, Turing.HMC(num_samples, epsilon, tau))
chn = sample(bif, Turing.NUTS(num_samples, delta))
end
4 changes: 2 additions & 2 deletions test/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
priors = [Normal(1.5,0.01)]

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500,epsilon = 0.001)
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500)

@show mean(bayesian_result[:theta1][50:end])

Expand All @@ -36,7 +36,7 @@ data = convert(Array,randomized)
priors = [Truncated(Normal(1.5,0.01),0,2),Truncated(Normal(1.0,0.01),0,1.5),
Truncated(Normal(3.0,0.01),0,4),Truncated(Normal(1.0,0.01),0,2)]

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500,epsilon = 0.001)
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500)

@show mean(bayesian_result[:theta1][50:end])
@show mean(bayesian_result[:theta2][50:end])
Expand Down

0 comments on commit d3282b2

Please sign in to comment.