Skip to content

Commit

Permalink
Decouple rand and eltype
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 25, 2024
1 parent a1010e4 commit 0ea5502
Show file tree
Hide file tree
Showing 19 changed files with 166 additions and 36 deletions.
34 changes: 11 additions & 23 deletions src/genericrand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,23 @@ function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
end

# multiple samples
function rand(rng::AbstractRNG, s::Sampleable{Univariate}, dims::Dims)
out = Array{eltype(s)}(undef, dims)
return @inbounds rand!(rng, sampler(s), out)
# we use function barriers since for some distributions `sampler(s)` is not type-stable:
# https://github.com/JuliaStats/Distributions.jl/pull/1281
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
return _rand(rng, sampler(s), dims)
end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
)
sz = size(s)
ax = map(Base.OneTo, dims)
out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
return @inbounds rand!(rng, sampler(s), out, false)
function _rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
r = rand(rng, s)
out = Array{typeof(r)}(undef, dims)
out[1] = r
rand!(rng, s, @view(out[2:end]))
return out
end

# these are workarounds for sampleables that incorrectly base `eltype` on the parameters
# this is a workaround for sampleables that incorrectly base `eltype` on the parameters
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
return @inbounds rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
end
function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
out = Array{float(eltype(s))}(undef, dims)
return @inbounds rand!(rng, sampler(s), out)
end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
)
sz = size(s)
ax = map(Base.OneTo, dims)
out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
return @inbounds rand!(rng, sampler(s), out, false)
end

"""
rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)
Expand Down
9 changes: 9 additions & 0 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ end

# sampling

function rand(rng::AbstractRNG, d::Union{Dirichlet,DirichletCanon})
x = map(αi -> rand(rng, Gamma(αi)), d.alpha)
return lmul!(inv(sum(x)), x)
end
function rand(rng::AbstractRNG, d::Dirichlet{<:Real,<:FillArrays.AbstractFill{<:Real}})
x = rand(rng, Gamma(FillArrays.getindex_value(d.alpha)), length(d))
return lmul!(inv(sum(x)), x)
end

function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
Expand Down
2 changes: 2 additions & 0 deletions src/multivariate/dirichletmultinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ end


# Sampling
rand(rng::AbstractRNG, d::DirichletMultinomial) =
multinom_rand(rng, ntrials(d), rand(rng, Dirichlet(d.α)))
_rand!(rng::AbstractRNG, d::DirichletMultinomial, x::AbstractVector{<:Real}) =
multinom_rand!(rng, ntrials(d), rand(rng, Dirichlet(d.α)), x)

Expand Down
21 changes: 21 additions & 0 deletions src/multivariate/jointorderstatistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,27 @@ function _marginalize_range(dist, i, j, xᵢ, xⱼ, T)
return k * T(logdiffcdf(dist, xⱼ, xᵢ)) - loggamma(T(k + 1))
end

function rand(rng::AbstractRNG, d::JointOrderStatistics)
n = d.n
if n == length(d.ranks) # ranks == 1:n
# direct method, slower than inversion method for large `n` and distributions with
# fast quantile function or that use inversion sampling
x = rand(rng, d.dist, n)
sort!(x)
else
# use exponential generation method with inversion, where for gaps in the ranks, we
# use the fact that the sum Y of k IID variables xₘ ~ Exp(1) is Y ~ Gamma(k, 1).
# Lurie, D., and H. O. Hartley. "Machine-generation of order statistics for Monte
# Carlo computations." The American Statistician 26.1 (1972): 26-27.
# this is slow if length(d.ranks) is close to n and quantile for d.dist is expensive,
# but this branch is probably taken when length(d.ranks) is small or much smaller than n.
xi = rand(rng, d.dist) # this is only used to obtain the type of samples from `d.dist`
x = Vector{typeof(xi)}(undef, length(d.ranks))
_rand!(rng, d, x)
end
return x
end

function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:Real})
n = d.n
if n == length(d.ranks) # ranks == 1:n
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/multinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ end
# Sampling

# if only a single sample is requested, no alias table is created
rand(rng::AbstractRNG, d::Multinomial) = multinom_rand(rng, ntrials(d), probs(d))
_rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
multinom_rand!(rng, ntrials(d), probs(d), x)

Expand Down
13 changes: 13 additions & 0 deletions src/multivariate/mvlogitnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.norm

# Sampling

function rand(rng::AbstractRNG, d::MvLogitNormal)
x = rand(rng, d.normal)
push!(x, zero(eltype(x)))
StatsFuns.softmax!(x)
return x
end
function rand(rng::AbstractRNG, d::MvLogitNormal, n::Int)
r = rand(rng, d.normal, n)
x = vcat(r, zeros(eltype(r), 1, n))
StatsFuns.softmax!(x; dims=1)
return x
end

function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
y = @views _drop1(x)
rand!(rng, d.normal, y)
Expand Down
11 changes: 11 additions & 0 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ var(d::MvLogNormal) = diag(cov(d))
entropy(d::MvLogNormal) = length(d)*(1+log2π)/2 + logdetcov(d.normal)/2 + sum(mean(d.normal))

#See https://en.wikipedia.org/wiki/Log-normal_distribution
function rand(rng::AbstractRNG, d::MvLogNormal)
x = rand(rng, d.normal)
map!(exp, x, x)
return x
end
function rand(rng::AbstractRNG, d::MvLogNormal, n::Int)
xs = rand(rng, d.normal, n)
map!(exp, xs, xs)
return xs
end

function _rand!(rng::AbstractRNG, d::MvLogNormal, x::AbstractVecOrMat{<:Real})
_rand!(rng, d.normal, x)
map!(exp, x, x)
Expand Down
11 changes: 11 additions & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,17 @@ gradlogpdf(d::MvNormal, x::AbstractVector{<:Real}) = -(d.Σ \ (x .- d.μ))

# Sampling (for GenericMvNormal)

function rand(rng::AbstractRNG, d::MvNormal)
x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d)))
x .+= d.μ
return x
end
function rand(rng::AbstractRNG, d::MvNormal, n::Int)
x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d), n))
x .+= d.μ
return x
end

function _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat)
unwhiten!(d.Σ, randn!(rng, x))
x .+= d.μ
Expand Down
11 changes: 11 additions & 0 deletions src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ if isdefined(PDMats, :PDSparseMat)
unwhiten_winv!(J::PDSparseMat, x::AbstractVecOrMat) = x[:] = J.chol.PtL' \ x
end

function rand(rng::AbstractRNG, d::MvNormalCanon)
x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d)))
x .+= d.μ
return x
end
function rand(rng::AbstractRNG, d::MvNormalCanon, n::Int)
x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d), n))
x .+= d.μ
return x
end

function _rand!(rng::AbstractRNG, d::MvNormalCanon, x::AbstractVector)
unwhiten_winv!(d.J, randn!(rng, x))
x .+= d.μ
Expand Down
15 changes: 15 additions & 0 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,21 @@ function gradlogpdf(d::GenericMvTDist, x::AbstractVector{<:Real})
end

# Sampling (for GenericMvTDist)
function rand(rng::AbstractRNG, d::GenericMvTDist)
chisqd = Chisq{partype(d)}(d.df)
y = sqrt(rand(rng, chisqd) / d.df)
x = unwhiten!(d.Σ, randn(rng, typeof(y), length(d)))
x .= x ./ y .+ d.μ
x
end
function rand(rng::AbstractRNG, d::GenericMvTDist, n::Int)
chisqd = Chisq{partype(d)}(d.df)
y = rand(rng, chisqd, n)
x = unwhiten!(d.Σ, randn(rng, eltype(y), length(d), n))
x .= x ./ sqrt.(y' ./ d.df) .+ d.μ
x
end

function _rand!(rng::AbstractRNG, d::GenericMvTDist, x::AbstractVector{<:Real})
chisqd = Chisq{partype(d)}(d.df)
y = sqrt(rand(rng, chisqd) / d.df)
Expand Down
18 changes: 14 additions & 4 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ size(d::MultivariateDistribution)

# multiple multivariate, must allocate matrix
# TODO: inconsistency with other `ArrayLikeVariate`s and `rand(s, (n,))` - maybe remove?
rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int) =
@inbounds rand!(rng, sampler(s), Matrix{eltype(s)}(undef, length(s), n))
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, n::Int) =
@inbounds rand!(rng, sampler(s), Matrix{float(eltype(s))}(undef, length(s), n))
function rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int)
return _rand(rng, sampler(s), n)
end
function _rand(rng, s::Sampleable{Multivariate}, n::Int)
r = rand(rng, s)
out = Matrix{eltype(r)}(undef, length(r), n)
if n > 0
copyto!(out, r)
if n > 1
rand!(rng, s, @view(out[:, 2:n]))
end
end
return out
end

## domain

Expand Down
6 changes: 6 additions & 0 deletions src/samplers/multinomial.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
function multinom_rand(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real})
return multinom_rand!(rng, n, p, Vector{Int}(undef, length(p)))
end
function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real},
x::AbstractVector{<:Real})
k = length(p)
Expand Down Expand Up @@ -49,6 +52,9 @@ function MultinomialSampler(n::Int, prob::Vector{<:Real})
return MultinomialSampler(n, prob, AliasTable(prob))
end

function rand(rng::AbstractRNG, s::MultinomialSampler)
return _rand!(rng, s, Vector{Int}(undef, length(s.prob)))
end
function _rand!(rng::AbstractRNG, s::MultinomialSampler,
x::AbstractVector{<:Real})
n = s.n
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Base.:*(c::Real, d::Uniform) = Uniform(minmax(c * d.a, c * d.b)...)

#### Sampling

rand(rng::AbstractRNG, d::Uniform) = d.a + (d.b - d.a) * rand(rng)
rand(rng::AbstractRNG, d::Uniform{T}) where {T} = d.a + (d.b - d.a) * rand(rng, float(T))

_rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) =
A .= Base.Fix1(quantile, d).(rand!(rng, A))
Expand Down
3 changes: 1 addition & 2 deletions src/univariate/orderstatistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ end
function rand(rng::AbstractRNG, d::OrderStatistic)
# inverse transform sampling. Since quantile function is Qₓ(Uᵢₙ⁻¹(p)), we draw a random
# variable from Uᵢₙ and pass it through the quantile function of `d.dist`
T = eltype(d.dist)
b = _uniform_orderstatistic(d)
return T(quantile(d.dist, rand(rng, b)))
return quantile(d.dist, float(partype(d.dist))(rand(rng, b)))
end
4 changes: 3 additions & 1 deletion src/univariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ end
Generate a scalar sample from `d`. The general fallback is `quantile(d, rand())`.
"""
rand(rng::AbstractRNG, d::UnivariateDistribution) = quantile(d, rand(rng))
function rand(rng::AbstractRNG, d::UnivariateDistribution)
return quantile(d, rand(rng, float(partype(d))))
end

## statistics

Expand Down
14 changes: 12 additions & 2 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
T = typeof(rand(s))
@test samples isa Vector{T}
@test samples2 isa Vector{T}
@test samples3 isa Vector{T}
@test samples4 isa Vector{T}
@test length(samples) == length(samples2) == length(samples3) == length(samples4) == n
@test samples2 == samples
@test samples3 == samples4

Expand Down Expand Up @@ -289,7 +294,12 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
T = typeof(rand(s))
@test samples isa Vector{T}
@test samples2 isa Vector{T}
@test samples3 isa Vector{T}
@test samples4 isa Vector{T}
@test length(samples) == length(samples2) == length(samples3) == length(samples4) == n
@test samples2 == samples
@test samples3 == samples4

Expand Down
17 changes: 15 additions & 2 deletions test/univariate/continuous/logistic.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
test_cgf(Logistic(0, 1), (-0.99,0.99, 1f-2, -1f-2))
test_cgf(Logistic(100,10), (-0.099,0.099, 1f-2, -1f-2))
using Distributions
using Test

@testset "Logistic" begin
test_cgf(Logistic(0, 1), (-0.99,0.99, 1f-2, -1f-2))
test_cgf(Logistic(100,10), (-0.099,0.099, 1f-2, -1f-2))

# issue 1082
@testset "rand consistency" begin
for T in (Float32, Float64, BigFloat)
@test @inferred(rand(Logistic(T(0), T(1)))) isa T
@test @inferred(rand(Logistic(T(0), T(1)), 5)) isa Vector{T}
end
end
end
2 changes: 1 addition & 1 deletion test/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ using Test
@inferred(rand(TDist(big"1.0")))
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

end

for T in (Float32, Float64)
@test @inferred(rand(TDist(T(1)))) isa T
@test @inferred(rand(TDist(T(1)), 5)) isa Vector{T}
end
end
8 changes: 8 additions & 0 deletions test/univariate/continuous/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,12 @@ using Test
end
end
end

# issues 1252 and 1783
@testset "rand consistency" begin
for T in (Float32, Float64, BigFloat)
@test @inferred(rand(Uniform(T(0), T(1)))) isa T
@test @inferred(rand(Uniform(T(0), T(1)), 5)) isa Vector{T}
end
end
end

0 comments on commit 0ea5502

Please sign in to comment.