From 510dcc3cb323d2575ff58adff771bda1bfbf62bb Mon Sep 17 00:00:00 2001 From: Phillip Alday Date: Tue, 5 Mar 2024 13:35:05 -0600 Subject: [PATCH] use FMA where possible in fitting (#740) * use FMA where possible in fitting * use muladd everywhere * NEWS update * format --- NEWS.md | 2 ++ src/linalg.jl | 2 +- src/linalg/rankUpdate.jl | 18 +++++++++--------- src/linearmixedmodel.jl | 6 ++++-- src/remat.jl | 23 +++++++++++++---------- test/pls.jl | 2 +- 6 files changed, 30 insertions(+), 23 deletions(-) diff --git a/NEWS.md b/NEWS.md index 4a30410f3..e880cbc99 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ MixedModels v4.22.5 Release Notes ============================== +* Use `muladd` where possible to enable fused multiply-add (FMA) on architectures with hardware support. FMA will generally improve computational speed and gives more accurate rounding. [#740] * Replace broadcasted lambda with explicit loop and use `one`. This may result in a small performance improvement. [#738] MixedModels v4.22.4 Release Notes @@ -500,5 +501,6 @@ Package dependencies [#717]: https://github.com/JuliaStats/MixedModels.jl/issues/717 [#733]: https://github.com/JuliaStats/MixedModels.jl/issues/733 [#738]: https://github.com/JuliaStats/MixedModels.jl/issues/738 +[#740]: https://github.com/JuliaStats/MixedModels.jl/issues/740 [#744]: https://github.com/JuliaStats/MixedModels.jl/issues/744 [#748]: https://github.com/JuliaStats/MixedModels.jl/issues/748 diff --git a/src/linalg.jl b/src/linalg.jl index 5cf62bce9..91e628153 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -18,7 +18,7 @@ function LinearAlgebra.mul!( αbnz = α * bnz[ib] jj = brv[ib] for ia in nzrange(A, j) - C[arv[ia], jj] += anz[ia] * αbnz + C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj]) end end end diff --git a/src/linalg/rankUpdate.jl b/src/linalg/rankUpdate.jl index 1d215cb17..e6f1aacd6 100644 --- a/src/linalg/rankUpdate.jl +++ b/src/linalg/rankUpdate.jl @@ -22,7 +22,7 @@ function MixedModels.rankUpdate!( Cdiag = C.data.diag Adiag = A.diag @inbounds for idx in eachindex(Cdiag, Adiag) - Cdiag[idx] = β * Cdiag[idx] + α * abs2(Adiag[idx]) + Cdiag[idx] = muladd(β, Cdiag[idx], α * abs2(Adiag[idx])) end return C end @@ -52,7 +52,7 @@ function _columndot(rv, nz, rngi, rngj) while i ≤ ni && j ≤ nj @inbounds ri, rj = rv[rngi[i]], rv[rngj[j]] if ri == rj - @inbounds accum += nz[rngi[i]] * nz[rngj[j]] + @inbounds accum = muladd(nz[rngi[i]], nz[rngj[j]], accum) i += 1 j += 1 elseif ri < rj @@ -80,7 +80,7 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, rvj = rv[j] for i in k:lenrngjj kk = rangejj[i] - Cd[rv[kk], rvj] += nz[kk] * anzj + Cd[rv[kk], rvj] = muladd(nz[kk], anzj, Cd[rv[kk], rvj]) end end end @@ -88,9 +88,9 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, @inbounds for j in axes(C, 2) rngj = nzrange(A, j) for i in 1:(j - 1) - Cd[i, j] += α * _columndot(rv, nz, nzrange(A, i), rngj) + Cd[i, j] = muladd(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j]) end - Cd[j, j] += α * sum(i -> abs2(nz[i]), rngj) + Cd[j, j] = muladd(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j]) end end return C @@ -109,7 +109,7 @@ function rankUpdate!( isone(β) || rmul!(Cdiag, β) @inbounds for i in eachindex(Cdiag) - Cdiag[i] += α * sum(abs2, view(A, i, :)) + Cdiag[i] = muladd(α, sum(abs2, view(A, i, :)), Cdiag[i]) end return C @@ -132,9 +132,9 @@ function rankUpdate!( AtAij = 0 for idx in axes(A, 2) # because the second multiplicant is from A', swap index order - AtAij += A[iind, idx] * A[jind, idx] + AtAij = muladd(A[iind, idx], A[jind, idx], AtAij) end - Cdat[i, j, k] += α * AtAij + Cdat[i, j, k] = muladd(α, AtAij, Cdat[i, j, k]) end end @@ -152,7 +152,7 @@ function rankUpdate!( throw(ArgumentError("Columns of A must have exactly 1 nonzero")) for (r, nz) in zip(rowvals(A), nonzeros(A)) - dd[r] += α * abs2(nz) + dd[r] = muladd(α, abs2(nz), dd[r]) end return C diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 74251d6b7..3827f41ec 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -767,7 +767,9 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T} z = trm.z stride = size(z, 1) mul!( - view(rhs2, (rhsoffset + (trm.refs[i] - 1) * stride) .+ Base.OneTo(stride)), + view( + rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride) + ), adjoint(trm.λ), view(z, :, i), ) @@ -816,7 +818,7 @@ function objective(m::LinearMixedModel{T}) where {T} val = if isnothing(σ) logdet(m) + denomdf * (one(T) + log2π + log(pwrss(m) / denomdf)) else - denomdf * (log2π + 2 * log(σ)) + logdet(m) + pwrss(m) / σ^2 + muladd(denomdf, muladd(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2)) end return isempty(wts) ? val : val - T(2.0) * sum(log, wts) end diff --git a/src/remat.jl b/src/remat.jl index b350cd059..74d92ff88 100644 --- a/src/remat.jl +++ b/src/remat.jl @@ -284,7 +284,7 @@ function LinearAlgebra.mul!( @inbounds for (j, rrj) in enumerate(B.refs) αzj = α * zz[j] for i in 1:p - C[i, rrj] += αzj * Awt[j, i] + C[i, rrj] = muladd(αzj, Awt[j, i], C[i, rrj]) end end return C @@ -310,7 +310,7 @@ function LinearAlgebra.mul!( aki = α * Awt[k, i] kk = Int(rr[k]) for ii in 1:S - scr[ii, kk] += aki * Bwt[ii, k] + scr[ii, kk] = muladd(aki, Bwt[ii, k], scr[ii, kk]) end end for j in 1:q @@ -340,7 +340,7 @@ function LinearAlgebra.mul!( coljlast = Int(C.colptr[j + 1] - 1) K = searchsortedfirst(rv, i, Int(C.colptr[j]), coljlast, Base.Order.Forward) if K ≤ coljlast && rv[K] == i - nz[K] += Az[k] * Bz[k] + nz[K] = muladd(Az[k], Bz[k], nz[K]) else throw(ArgumentError("C does not have the nonzero pattern of A'B")) end @@ -361,7 +361,7 @@ function LinearAlgebra.mul!( @inbounds for i in 1:S zij = Awtz[i, j] for k in 1:S - Cd[k, i, r] += zij * Awtz[k, j] + Cd[k, i, r] = muladd(zij, Awtz[k, j], Cd[k, i, r]) end end end @@ -397,7 +397,7 @@ function LinearAlgebra.mul!( jjo = jj + joffset Bzijj = Bz[jj, i] for ii in 1:S - C[ii + ioffset, jjo] += Az[ii, i] * Bzijj + C[ii + ioffset, jjo] = muladd(Az[ii, i], Bzijj, C[ii + ioffset, jjo]) end end end @@ -416,7 +416,8 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) z = A.z @inbounds for (i, r) in enumerate(A.refs) - y[i] += alpha * b[r] * z[i] + # must be muladd and not fma because of potential missings + y[i] = muladd(alpha * b[r], z[i], y[i]) end return y end @@ -446,7 +447,8 @@ function LinearAlgebra.mul!( @inbounds for (i, ii) in enumerate(A.refs) offset = (ii - 1) * k for j in 1:k - y[i] += alpha * Z[j, i] * b[offset + j] + # must be muladd and not fma because of potential missings + y[i] = muladd(alpha * Z[j, i], b[offset + j], y[i]) end end return y @@ -466,7 +468,8 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) @inbounds for (i, ii) in enumerate(refarray(A)) for j in 1:k - y[i] += alpha * Z[j, i] * B[j, ii] + # must be muladd and not fma because of potential missings + y[i] = muladd(alpha * Z[j, i], B[j, ii], y[i]) end end return y @@ -566,7 +569,7 @@ function copyscaleinflate!(Ljj::Diagonal{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) Ldiag, Adiag = Ljj.diag, Ajj.diag lambsq = abs2(only(Λj.λ.data)) @inbounds for i in eachindex(Ldiag, Adiag) - Ldiag[i] = lambsq * Adiag[i] + one(T) + Ldiag[i] = muladd(lambsq, Adiag[i], one(T)) end return Ljj end @@ -575,7 +578,7 @@ function copyscaleinflate!(Ljj::Matrix{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) wh fill!(Ljj, zero(T)) lambsq = abs2(only(Λj.λ.data)) @inbounds for (i, a) in enumerate(Ajj.diag) - Ljj[i, i] = lambsq * a + one(T) + Ljj[i, i] = muladd(lambsq, a, one(T)) end return Ljj end diff --git a/test/pls.jl b/test/pls.jl index 693609a41..6dd8e9583 100644 --- a/test/pls.jl +++ b/test/pls.jl @@ -139,7 +139,7 @@ end vc = fm1.vcov @test isa(vc, Matrix{Float64}) - @test only(vc) ≈ 375.7167775 rtol=1.e-6 + @test only(vc) ≈ 375.7167775 rtol=1.e-3 # since we're caching the fits, we should get it back to being correctly fitted # we also take this opportunity to test fitlog @testset "fitlog" begin