diff --git a/src/Xymat.jl b/src/Xymat.jl index b85d81d8d..8928dc358 100644 --- a/src/Xymat.jl +++ b/src/Xymat.jl @@ -61,6 +61,15 @@ Base.copyto!(A::FeTerm{T}, src::AbstractVecOrMat{T}) where {T} = copyto!(A.x, sr Base.eltype(::FeTerm{T}) where {T} = T +""" + pivot(m::MixedModel) + pivot(A::FeTerm) + +Return the pivot associated with the FeTerm. +""" +@inline pivot(m::MixedModel) = pivot(m.feterm) +@inline pivot(A::FeTerm) = A.piv + function fullrankx(A::FeTerm) x, rnk = A.x, A.rank return rnk == size(x, 2) ? x : view(x, :, 1:rnk) # this handles the zero-columns case diff --git a/src/generalizedlinearmixedmodel.jl b/src/generalizedlinearmixedmodel.jl index c3c8b13be..66da6123e 100644 --- a/src/generalizedlinearmixedmodel.jl +++ b/src/generalizedlinearmixedmodel.jl @@ -54,7 +54,7 @@ struct GeneralizedLinearMixedModel{T<:AbstractFloat,D<:Distribution} <: MixedMod end function StatsAPI.coef(m::GeneralizedLinearMixedModel{T}) where {T} - piv = m.LMM.feterm.piv + piv = pivot(m) return invpermute!(copyto!(fill(T(-0.0), length(piv)), m.β), piv) end diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 1a006ae68..cc95b1a1a 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -267,11 +267,11 @@ function StatsAPI.fit( end function StatsAPI.coef(m::LinearMixedModel{T}) where {T} - return coef!(Vector{T}(undef, length(m.feterm.piv)), m) + return coef!(Vector{T}(undef, length(pivot(m))), m) end function coef!(v::AbstractVector{Tv}, m::MixedModel{T}) where {Tv,T} - piv = m.feterm.piv + piv = pivot(m) return invpermute!(fixef!(v, m), piv) end @@ -1243,12 +1243,12 @@ function stderror!(v::AbstractVector{Tv}, m::LinearMixedModel{T}) where {Tv,T} scr[i] = true v[i] = s * norm(ldiv!(L, scr)) end - invpermute!(v, m.feterm.piv) + invpermute!(v, pivot(m)) return v end function StatsAPI.stderror(m::LinearMixedModel{T}) where {T} - return stderror!(similar(m.feterm.piv, T), m) + return stderror!(similar(pivot(m), T), m) end """ diff --git a/src/predict.jl b/src/predict.jl index d540af353..9accf9007 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -70,7 +70,7 @@ Similarly, offsets are also not supported for `GeneralizedLinearMixedModel`. function StatsAPI.predict( m::LinearMixedModel, newdata::Tables.ColumnTable; new_re_levels=:missing ) - return _predict(m, newdata, coef(m)[m.feterm.piv]; new_re_levels) + return _predict(m, newdata, coef(m)[pivot(m)]; new_re_levels) end function StatsAPI.predict( @@ -81,7 +81,7 @@ function StatsAPI.predict( ) type in (:linpred, :response) || throw(ArgumentError("Invalid value for type: $(type)")) # want pivoted but not truncated - y = _predict(m.LMM, newdata, coef(m)[m.feterm.piv]; new_re_levels) + y = _predict(m.LMM, newdata, coef(m)[pivot(m)]; new_re_levels) return type == :linpred ? y : broadcast!(Base.Fix1(linkinv, Link(m)), y, y) end @@ -126,7 +126,7 @@ function _predict(m::MixedModel{T}, newdata, β; new_re_levels) where {T} ytemp, lmm end - pivotmatch = mnew.feterm.piv[m.feterm.piv] + pivotmatch = mnew.feterm.piv[pivot(m)] grps = fnames(m) mul!(y, view(mnew.X, :, pivotmatch), β) # mnew.reterms for the correct Z matrices diff --git a/src/simulate.jl b/src/simulate.jl index 1bd0184bd..b129e915e 100644 --- a/src/simulate.jl +++ b/src/simulate.jl @@ -149,7 +149,7 @@ end function simulate!( rng::AbstractRNG, y::AbstractVector, m::LinearMixedModel{T}; β=m.β, σ=m.σ, θ=m.θ ) where {T} - length(β) == length(m.feterm.piv) || length(β) == m.feterm.rank || + length(β) == length(pivot(m)) || length(β) == m.feterm.rank || throw(ArgumentError("You must specify all (non-singular) βs")) β = convert(Vector{T}, β) @@ -157,8 +157,8 @@ function simulate!( θ = convert(Vector{T}, θ) isempty(θ) || setθ!(m, θ) - if length(β) ≠ length(m.feterm.piv) - β = invpermute!(copyto!(fill(-0.0, length(m.feterm.piv)), β), + if length(β) ≠ length(pivot(m)) + β = invpermute!(copyto!(fill(-0.0, length(pivot(m))), β), m.feterm.piv) end @@ -228,7 +228,7 @@ function _simulate!( θ, resp=nothing, ) where {T} - length(β) == length(m.feterm.piv) || length(β) == m.feterm.rank || + length(β) == length(pivot(m)) || length(β) == m.feterm.rank || throw(ArgumentError("You must specify all (non-singular) βs")) dispersion_parameter(m) || @@ -249,7 +249,7 @@ function _simulate!( if length(β) == length(m.feterm.piv) # unlike LMM, GLMM stores the truncated, pivoted vector directly - β = β[view(m.feterm.piv, 1:(m.feterm.rank))] + β = β[view(pivot(m), 1:(m.feterm.rank))] end fast = (length(m.θ) == length(m.optsum.final)) setpar! = fast ? setθ! : setβθ!