Skip to content

Commit

Permalink
add pivot accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
palday committed Mar 25, 2024
1 parent 41609b5 commit 2d74d0a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 13 deletions.
9 changes: 9 additions & 0 deletions src/Xymat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ Base.copyto!(A::FeTerm{T}, src::AbstractVecOrMat{T}) where {T} = copyto!(A.x, sr

Base.eltype(::FeTerm{T}) where {T} = T

"""
pivot(m::MixedModel)
pivot(A::FeTerm)
Return the pivot associated with the FeTerm.
"""
@inline pivot(m::MixedModel) = pivot(m.feterm)
@inline pivot(A::FeTerm) = A.piv

function fullrankx(A::FeTerm)
x, rnk = A.x, A.rank
return rnk == size(x, 2) ? x : view(x, :, 1:rnk) # this handles the zero-columns case
Expand Down
2 changes: 1 addition & 1 deletion src/generalizedlinearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct GeneralizedLinearMixedModel{T<:AbstractFloat,D<:Distribution} <: MixedMod
end

function StatsAPI.coef(m::GeneralizedLinearMixedModel{T}) where {T}
piv = m.LMM.feterm.piv
piv = pivot(m)
return invpermute!(copyto!(fill(T(-0.0), length(piv)), m.β), piv)
end

Expand Down
8 changes: 4 additions & 4 deletions src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,11 @@ function StatsAPI.fit(
end

function StatsAPI.coef(m::LinearMixedModel{T}) where {T}
return coef!(Vector{T}(undef, length(m.feterm.piv)), m)
return coef!(Vector{T}(undef, length(pivot(m))), m)
end

function coef!(v::AbstractVector{Tv}, m::MixedModel{T}) where {Tv,T}
piv = m.feterm.piv
piv = pivot(m)
return invpermute!(fixef!(v, m), piv)
end

Expand Down Expand Up @@ -1243,12 +1243,12 @@ function stderror!(v::AbstractVector{Tv}, m::LinearMixedModel{T}) where {Tv,T}
scr[i] = true
v[i] = s * norm(ldiv!(L, scr))
end
invpermute!(v, m.feterm.piv)
invpermute!(v, pivot(m))
return v
end

function StatsAPI.stderror(m::LinearMixedModel{T}) where {T}
return stderror!(similar(m.feterm.piv, T), m)
return stderror!(similar(pivot(m), T), m)
end

"""
Expand Down
6 changes: 3 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Similarly, offsets are also not supported for `GeneralizedLinearMixedModel`.
function StatsAPI.predict(
m::LinearMixedModel, newdata::Tables.ColumnTable; new_re_levels=:missing
)
return _predict(m, newdata, coef(m)[m.feterm.piv]; new_re_levels)
return _predict(m, newdata, coef(m)[pivot(m)]; new_re_levels)
end

function StatsAPI.predict(
Expand All @@ -81,7 +81,7 @@ function StatsAPI.predict(
)
type in (:linpred, :response) || throw(ArgumentError("Invalid value for type: $(type)"))
# want pivoted but not truncated
y = _predict(m.LMM, newdata, coef(m)[m.feterm.piv]; new_re_levels)
y = _predict(m.LMM, newdata, coef(m)[pivot(m)]; new_re_levels)

return type == :linpred ? y : broadcast!(Base.Fix1(linkinv, Link(m)), y, y)
end
Expand Down Expand Up @@ -126,7 +126,7 @@ function _predict(m::MixedModel{T}, newdata, β; new_re_levels) where {T}
ytemp, lmm
end

pivotmatch = mnew.feterm.piv[m.feterm.piv]
pivotmatch = mnew.feterm.piv[pivot(m)]
grps = fnames(m)
mul!(y, view(mnew.X, :, pivotmatch), β)
# mnew.reterms for the correct Z matrices
Expand Down
10 changes: 5 additions & 5 deletions src/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ end
function simulate!(
rng::AbstractRNG, y::AbstractVector, m::LinearMixedModel{T}; β=m.β, σ=m.σ, θ=m.θ
) where {T}
length(β) == length(m.feterm.piv) || length(β) == m.feterm.rank ||
length(β) == length(pivot(m)) || length(β) == m.feterm.rank ||
throw(ArgumentError("You must specify all (non-singular) βs"))

β = convert(Vector{T}, β)
σ = T(σ)
θ = convert(Vector{T}, θ)
isempty(θ) || setθ!(m, θ)

if length(β) length(m.feterm.piv)
β = invpermute!(copyto!(fill(-0.0, length(m.feterm.piv)), β),
if length(β) length(pivot(m))
β = invpermute!(copyto!(fill(-0.0, length(pivot(m))), β),
m.feterm.piv)
end

Expand Down Expand Up @@ -228,7 +228,7 @@ function _simulate!(
θ,
resp=nothing,
) where {T}
length(β) == length(m.feterm.piv) || length(β) == m.feterm.rank ||
length(β) == length(pivot(m)) || length(β) == m.feterm.rank ||
throw(ArgumentError("You must specify all (non-singular) βs"))

dispersion_parameter(m) ||
Expand All @@ -249,7 +249,7 @@ function _simulate!(

if length(β) == length(m.feterm.piv)
# unlike LMM, GLMM stores the truncated, pivoted vector directly
β = β[view(m.feterm.piv, 1:(m.feterm.rank))]
β = β[view(pivot(m), 1:(m.feterm.rank))]
end
fast = (length(m.θ) == length(m.optsum.final))
setpar! = fast ? setθ! : setβθ!
Expand Down

0 comments on commit 2d74d0a

Please sign in to comment.