Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent eltype and allow to specify type in rand #1433

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/cholesky/lkjcholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ end
# Properties
# -----------------------------------------------------------------------------

Base.eltype(::Type{LKJCholesky{T}}) where {T} = T

function Base.size(d::LKJCholesky)
p = d.d
return (p, p)
Expand Down Expand Up @@ -150,15 +148,14 @@ end
# Sampling
# -----------------------------------------------------------------------------

function Base.rand(rng::AbstractRNG, d::LKJCholesky)
factors = Matrix{eltype(d)}(undef, size(d))
function Base.rand(rng::AbstractRNG, ::Type{T}, d::LKJCholesky) where {T}
factors = Matrix{T}(undef, size(d))
R = LinearAlgebra.Cholesky(factors, d.uplo, 0)
return _lkj_cholesky_onion_sampler!(rng, d, R)
end
function Base.rand(rng::AbstractRNG, d::LKJCholesky, dims::Dims)
function Base.rand(rng::AbstractRNG, ::Type{T}, d::LKJCholesky, dims::Dims) where {T}
p = d.d
uplo = d.uplo
T = eltype(d)
TM = Matrix{T}
Rs = Array{LinearAlgebra.Cholesky{T,TM}}(undef, dims)
for i in eachindex(Rs)
Expand Down
7 changes: 4 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ Base.size(s::Sampleable{Multivariate}) = (length(s),)
"""
eltype(::Type{Sampleable})

The default element type of a sample. This is the type of elements of the samples generated
by the `rand` method. However, one can provide an array of different element types to
store the samples using `rand!`.
The default element type of a sample.

This is the type of elements of the samples generated by the `rand` method. However, one can
provide an array of different element types to store the samples using `rand!`.
"""
Base.eltype(::Type{<:Sampleable{F,Discrete}}) where {F} = Int
Base.eltype(::Type{<:Sampleable{F,Continuous}}) where {F} = Float64
Expand Down
47 changes: 29 additions & 18 deletions src/genericrand.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
### Generic rand methods

"""
rand([rng::AbstractRNG,] s::Sampleable)
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{T}=eltype(s), s::Sampleable)

Generate one sample for `s`.
Generate one sample for `s` of elment type `T`.
"""
rand(s::Sampleable) = rand(eltype(s), s)
rand(::Type{T}, s::Sampleable) where {T} = rand(GLOBAL_RNG, T, s)

rand([rng::AbstractRNG,] s::Sampleable, n::Int)
"""
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{T}=eltype(s), s::Sampleable, n::Int)

Generate `n` samples from `s`. The form of the returned object depends on the variate form of `s`:
Generate `n` samples from `s` of element type `T`.

The form of the returned object depends on the variate form of `s`:
- When `s` is univariate, it returns a vector of length `n`.
- When `s` is multivariate, it returns a matrix with `n` columns.
- When `s` is matrix-variate, it returns an array, where each element is a sample matrix.
"""
rand(s::Sampleable, n::Int) = rand(eltype(s), s, n)
rand(::Type{T}, s::Sampleable, n::Int) where {T} = rand(GLOBAL_RNG, T, s, n)
rand(rng::AbstractRNG, ::Type{T}, s::Sampleable, n::Int) where {T} = rand(rng, T, s, (n,))

rand([rng::AbstractRNG,] s::Sampleable, dim1::Int, dim2::Int...)
rand([rng::AbstractRNG,] s::Sampleable, dims::Dims)
"""
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{T}=eltype(s), s::Sampleable, dims::Int...)
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{T}=eltype(s), s::Sampleable, dims::Dims)

Generate an array of samples from `s` whose shape is determined by the given
Generate an array of samples from `s` of element type `T` whose shape is determined by the given
dimensions.
"""
rand(s::Sampleable) = rand(GLOBAL_RNG, s)
rand(s::Sampleable, dims::Dims) = rand(GLOBAL_RNG, s, dims)
rand(s::Sampleable, dim1::Int, moredims::Int...) =
rand(GLOBAL_RNG, s, (dim1, moredims...))
rand(rng::AbstractRNG, s::Sampleable, dim1::Int, moredims::Int...) =
rand(rng, s, (dim1, moredims...))
rand(s::Sampleable, dims::Dims) = rand(eltype(s), s, dims)
function rand(s::Sampleable, dims1::Int, dims2::Int, dims::Int...)
return rand(eltype(s), s, dims1, dims2, dims...)
end
function rand(::Type{T}, s::Sampleable, dims1::Int, dims2::Int, dims::Int...) where {T}
return rand(T, s, (dims1, dims2, dims...))
end
rand(::Type{T}, s::Sampleable, dims::Dims) where {T} = rand(GLOBAL_RNG, T, s, dims)

"""
rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)
Expand All @@ -46,13 +58,12 @@ rand!(s::Sampleable, X::AbstractArray) = rand!(GLOBAL_RNG, s, X)
rand!(rng::AbstractRNG, s::Sampleable, X::AbstractArray) = _rand!(rng, s, X)

"""
sampler(d::Distribution) -> Sampleable
sampler(s::Sampleable) -> s
sampler(s::Sampleable)

Return a sampler that is used for batch sampling.

Samplers can often rely on pre-computed quantities (that are not parameters
themselves) to improve efficiency. If such a sampler exists, it can be provided
with this `sampler` method, which would be used for batch sampling.
The general fallback is `sampler(d::Distribution) = d`.
themselves) to improve efficiency. The general fallback is `sampler(s) = s`.
"""
sampler(s::Sampleable) = s

Expand Down
14 changes: 6 additions & 8 deletions src/matrixvariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,14 @@ function rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
end

# multiple matrix-variates, must allocate array of arrays
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate}, dims::Dims) =
rand!(rng, s, Array{Matrix{eltype(s)}}(undef, dims), true)
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate,Continuous}, dims::Dims) =
rand!(rng, s, Array{Matrix{float(eltype(s))}}(undef, dims), true)
function rand(rng::AbstractRNG, ::Type{T}, s::Sampleable{Matrixvariate}, dims::Dims) where {T}
return rand!(rng, s, Array{Matrix{T}}(undef, dims), true)
end

# single matrix-variate, must allocate one matrix
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate}) =
_rand!(rng, s, Matrix{eltype(s)}(undef, size(s)))
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate,Continuous}) =
_rand!(rng, s, Matrix{float(eltype(s))}(undef, size(s)))
function rand(rng::AbstractRNG, ::Type{T}, s::Sampleable{Matrixvariate}) where {T}
return _rand!(rng, s, Matrix{T}(undef, size(s)))
end

# single matrix-variate with pre-allocated matrix
function rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
Expand Down
10 changes: 6 additions & 4 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,12 @@ function MixtureSampler(d::MixtureModel{VF,VS}) where {VF,VS}
MixtureSampler{VF,VS,eltype(csamplers)}(csamplers, psampler)
end

rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) =
rand(rng, s.csamplers[rand(rng, s.psampler)])
rand(rng::AbstractRNG, d::MixtureModel{Univariate}) =
rand(rng, component(d, rand(rng, d.prior)))
function rand(rng::AbstractRNG, ::Type{T}, s::MixtureSampler{Univariate}) where {T}
return rand(rng, T, s.csamplers[rand(rng, s.psampler)])
end
function rand(rng::AbstractRNG, ::Type{T}, d::MixtureModel{Univariate}) where {T}
return rand(rng, T, component(d, rand(rng, d.prior)))
end

# multivariate mixture sampler for a vector
_rand!(rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector) =
Expand Down
15 changes: 9 additions & 6 deletions src/mixtures/unigmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ probs(d::UnivariateGMM) = probs(d.prior)

mean(d::UnivariateGMM) = dot(d.means, probs(d))

rand(d::UnivariateGMM) = (k = rand(d.prior); d.means[k] + randn() * d.stds[k])

rand(rng::AbstractRNG, d::UnivariateGMM) =
(k = rand(rng, d.prior); d.means[k] + randn(rng) * d.stds[k])
function rand(rng::AbstractRNG, ::Type{T}, d::UnivariateGMM) where {T}
k = rand(rng, d.prior)
return T(d.means[k]) + randn(rng, T) * T(d.stds[k])
end

params(d::UnivariateGMM) = (d.means, d.stds, d.prior)

Expand All @@ -38,6 +38,9 @@ struct UnivariateGMMSampler{VT1<:AbstractVector{<:Real},VT2<:AbstractVector{<:Re
psampler::AliasTable
end

rand(rng::AbstractRNG, s::UnivariateGMMSampler) =
(k = rand(rng, s.psampler); s.means[k] + randn(rng) * s.stds[k])
function rand(rng::AbstractRNG, ::Type{T}, s::UnivariateGMMSampler) where {T}
k = rand(rng, s.psampler)
return T(s.means[k]) + randn(rng, T) * T(s.stds[k])
end

sampler(d::UnivariateGMM) = UnivariateGMMSampler(d.means, d.stds, sampler(d.prior))
2 changes: 0 additions & 2 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ end

length(d::DirichletCanon) = length(d.alpha)

Base.eltype(::Type{<:Dirichlet{T}}) where {T} = T

#### Conversions
convert(::Type{Dirichlet{T}}, cf::DirichletCanon) where {T<:Real} =
Dirichlet(convert(AbstractVector{T}, cf.alpha))
Expand Down
2 changes: 0 additions & 2 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ MvLogNormal(μ::AbstractVector,s::Real) = MvLogNormal(MvNormal(μ,s))
MvLogNormal(σ::AbstractVector) = MvLogNormal(MvNormal(σ))
MvLogNormal(d::Int,s::Real) = MvLogNormal(MvNormal(d,s))

Base.eltype(::Type{<:MvLogNormal{T}}) where {T} = T

### Conversion
function convert(::Type{MvLogNormal{T}}, d::MvLogNormal) where T<:Real
MvLogNormal(convert(MvNormal{T}, d.normal))
Expand Down
6 changes: 2 additions & 4 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end
insupport(d::AbstractMvNormal, x::AbstractVector) =
length(d) == length(x) && all(isfinite, x)

minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), length(d))
maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), length(d))
minimum(d::AbstractMvNormal) = fill(-Inf, length(d))
maximum(d::AbstractMvNormal) = fill(Inf, length(d))
mode(d::AbstractMvNormal) = mean(d)
modes(d::AbstractMvNormal) = [mean(d)]

Expand Down Expand Up @@ -222,8 +222,6 @@ Base.@deprecate MvNormal(μ::AbstractVector{<:Real}, σ::Real) MvNormal(μ, σ^2
Base.@deprecate MvNormal(σ::AbstractVector{<:Real}) MvNormal(LinearAlgebra.Diagonal(map(abs2, σ)))
Base.@deprecate MvNormal(d::Int, σ::Real) MvNormal(LinearAlgebra.Diagonal(FillArrays.Fill(σ^2, d)))

Base.eltype(::Type{<:MvNormal{T}}) where {T} = T

### Conversion
function convert(::Type{MvNormal{T}}, d::MvNormal) where T<:Real
MvNormal(convert(AbstractArray{T}, d.μ), convert(AbstractArray{T}, d.Σ))
Expand Down
1 change: 0 additions & 1 deletion src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ length(d::MvNormalCanon) = length(d.μ)
mean(d::MvNormalCanon) = convert(Vector{eltype(d.μ)}, d.μ)
params(d::MvNormalCanon) = (d.μ, d.h, d.J)
@inline partype(d::MvNormalCanon{T}) where {T<:Real} = T
Base.eltype(::Type{<:MvNormalCanon{T}}) where {T} = T

var(d::MvNormalCanon) = diag(inv(d.J))
cov(d::MvNormalCanon) = Matrix(inv(d.J))
Expand Down
3 changes: 1 addition & 2 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ invcov(d::GenericMvTDist) = d.df>2 ? ((d.df-2)/d.df)*Matrix(inv(d.Σ)) : NaN*one
logdet_cov(d::GenericMvTDist) = d.df>2 ? logdet((d.df/(d.df-2))*d.Σ) : NaN

params(d::GenericMvTDist) = (d.df, d.μ, d.Σ)
@inline partype(d::GenericMvTDist{T}) where {T} = T
Base.eltype(::Type{<:GenericMvTDist{T}}) where {T} = T
@inline partype(::GenericMvTDist{T}) where {T} = T

# For entropy calculations see "Multivariate t Distributions and their Applications", S. Kotz & S. Nadarajah
function entropy(d::GenericMvTDist)
Expand Down
22 changes: 9 additions & 13 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,17 @@ function rand!(rng::AbstractRNG, s::Sampleable{Multivariate},
end

# multiple multivariate, must allocate matrix or array of vectors
rand(s::Sampleable{Multivariate}, n::Int) = rand(GLOBAL_RNG, s, n)
rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int) =
_rand!(rng, s, Matrix{eltype(s)}(undef, length(s), n))
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, n::Int) =
_rand!(rng, s, Matrix{float(eltype(s))}(undef, length(s), n))
rand(rng::AbstractRNG, s::Sampleable{Multivariate}, dims::Dims) =
rand!(rng, s, Array{Vector{eltype(s)}}(undef, dims), true)
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, dims::Dims) =
rand!(rng, s, Array{Vector{float(eltype(s))}}(undef, dims), true)
function rand(rng::AbstractRNG, ::Type{T}, s::Sampleable{Multivariate,Continuous}, n::Int) where {T}
return _rand!(rng, s, Matrix{T}(undef, length(s), n))
end
function rand(rng::AbstractRNG, ::Type{T}, s::Sampleable{Multivariate}, dims::Dims) where {T}
return rand!(rng, s, Array{Vector{T}}(undef, dims), true)
end

# single multivariate, must allocate vector
rand(rng::AbstractRNG, s::Sampleable{Multivariate}) =
_rand!(rng, s, Vector{eltype(s)}(undef, length(s)))
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}) =
_rand!(rng, s, Vector{float(eltype(s))}(undef, length(s)))
function rand(rng::AbstractRNG, ::Type{T}, s::Sampleable{Multivariate}) where {T}
return _rand!(rng, s, Vector{T}(undef, length(s)))
end

## domain

Expand Down
4 changes: 2 additions & 2 deletions src/samplers/aliastable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ function AliasTable(probs::AbstractVector)
AliasTable(accp, alias)
end

function rand(rng::AbstractRNG, s::AliasTable)
function rand(rng::AbstractRNG, ::Type{T}, s::AliasTable) where {T}
i = rand(rng, 1:length(s.alias)) % Int
u = rand(rng)
@inbounds r = u < s.accept[i] ? i : s.alias[i]
r
return T(r)
end

show(io::IO, s::AliasTable) = @printf(io, "AliasTable with %d entries", ncategories(s))
33 changes: 17 additions & 16 deletions src/samplers/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ function BinomialGeomSampler(n::Int, prob::Float64)
BinomialGeomSampler(comp, n, scale)
end

function rand(rng::AbstractRNG, s::BinomialGeomSampler)
y = 0
x = 0
n = s.n
function rand(rng::AbstractRNG, ::Type{T}, s::BinomialGeomSampler) where {T}
y = zero(T)
x = zero(T)
n = T(s.n)
while true
er = randexp(rng)
v = er * s.scale
if v > n # in case when v is very large or infinity
break
end
y += ceil(Int,v)
y += T(ceil(v))
if y > n
break
end
x += 1
end
(s.comp ? s.n - x : x)::Int
return s.comp ? n - x : x
end


Expand Down Expand Up @@ -129,14 +129,14 @@ function BinomialTPESampler(n::Int, prob::Float64)
xM,xL,xR,c,λL,λR)
end

function rand(rng::AbstractRNG, s::BinomialTPESampler)
y = 0
function rand(rng::AbstractRNG, ::Type{T}, s::BinomialTPESampler) where {T}
y = zero(T)
while true
# Step 1
u = s.p4*rand(rng)
v = rand(rng)
if u <= s.p1
y = floor(Int,s.xM-s.p1*v+u)
y = T(floor(s.xM-s.p1*v+u))
# Goto 6
break
elseif u <= s.p2 # Step 2
Expand All @@ -146,18 +146,18 @@ function rand(rng::AbstractRNG, s::BinomialTPESampler)
# Goto 1
continue
end
y = floor(Int,x)
y = T(floor(x))
# Goto 5
elseif u <= s.p3 # Step 3
y = floor(Int,s.xL + log(v)/s.λL)
y = T(floor(s.xL + log(v)/s.λL))
if y < 0
# Goto 1
continue
end
v *= (u-s.p2)*s.λL
# Goto 5
else # Step 4
y = floor(Int,s.xR-log(v)/s.λR)
y = T(floor(s.xR-log(v)/s.λR))
if y > s.n
# Goto 1
continue
Expand Down Expand Up @@ -219,7 +219,7 @@ function rand(rng::AbstractRNG, s::BinomialTPESampler)
end
end
# 6
(s.comp ? s.n - y : y)::Int
return s.comp ? T(s.n) - y : y
end


Expand All @@ -231,7 +231,7 @@ end

BinomialAliasSampler(n::Int, p::Float64) = BinomialAliasSampler(AliasTable(binompvec(n, p)))

rand(rng::AbstractRNG, s::BinomialAliasSampler) = rand(rng, s.table) - 1
rand(rng::AbstractRNG, ::Type{T}, s::BinomialAliasSampler) where {T} = rand(rng, T, s.table) - 1


# Integrated Polyalgorithm sampler that automatically chooses the proper one
Expand Down Expand Up @@ -260,5 +260,6 @@ end

BinomialPolySampler(n::Real, p::Real) = BinomialPolySampler(round(Int, n), Float64(p))

rand(rng::AbstractRNG, s::BinomialPolySampler) =
s.use_btpe ? rand(rng, s.btpe_sampler) : rand(rng, s.geom_sampler)
function rand(rng::AbstractRNG, ::Type{T}, s::BinomialPolySampler) where {T}
return rand(rng, T, s.use_btpe ? s.btpe_sampler : s.geom_sampler)
end
6 changes: 3 additions & 3 deletions src/samplers/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ CategoricalDirectSampler(p::Ts) where {T<:Real,Ts<:AbstractVector{T}} =

ncategories(s::CategoricalDirectSampler) = length(s.prob)

function rand(rng::AbstractRNG, s::CategoricalDirectSampler)
function rand(rng::AbstractRNG, ::Type{T}, s::CategoricalDirectSampler) where {T}
p = s.prob
n = length(p)
i = 1
c = p[1]
u = rand(rng)
u = rand(rng, typeof(float(c)))
while c < u && i < n
c += p[i += 1]
end
return i
return T(i)
end
Loading