From d0f24f958aa97ee50e2d2f0621019aae4165c5fe Mon Sep 17 00:00:00 2001 From: Phillip Alday Date: Fri, 26 Jan 2024 13:48:08 -0600 Subject: [PATCH 1/4] use FMA where possible in fitting --- src/linalg.jl | 2 +- src/linalg/rankUpdate.jl | 18 +++++++++--------- src/linearmixedmodel.jl | 4 ++-- src/remat.jl | 30 ++++++++++++++++++------------ test/pls.jl | 2 +- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 5cf62bce9..59a40478d 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -18,7 +18,7 @@ function LinearAlgebra.mul!( αbnz = α * bnz[ib] jj = brv[ib] for ia in nzrange(A, j) - C[arv[ia], jj] += anz[ia] * αbnz + C[arv[ia], jj] = fma(anz[ia], αbnz, C[arv[ia], jj]) end end end diff --git a/src/linalg/rankUpdate.jl b/src/linalg/rankUpdate.jl index 1d215cb17..a0696746c 100644 --- a/src/linalg/rankUpdate.jl +++ b/src/linalg/rankUpdate.jl @@ -22,7 +22,7 @@ function MixedModels.rankUpdate!( Cdiag = C.data.diag Adiag = A.diag @inbounds for idx in eachindex(Cdiag, Adiag) - Cdiag[idx] = β * Cdiag[idx] + α * abs2(Adiag[idx]) + Cdiag[idx] = fma(β,Cdiag[idx], α * abs2(Adiag[idx])) end return C end @@ -52,7 +52,7 @@ function _columndot(rv, nz, rngi, rngj) while i ≤ ni && j ≤ nj @inbounds ri, rj = rv[rngi[i]], rv[rngj[j]] if ri == rj - @inbounds accum += nz[rngi[i]] * nz[rngj[j]] + @inbounds accum = fma(nz[rngi[i]], nz[rngj[j]], accum) i += 1 j += 1 elseif ri < rj @@ -80,7 +80,7 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, rvj = rv[j] for i in k:lenrngjj kk = rangejj[i] - Cd[rv[kk], rvj] += nz[kk] * anzj + Cd[rv[kk], rvj] = fma(nz[kk], anzj, Cd[rv[kk], rvj]) end end end @@ -88,9 +88,9 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, @inbounds for j in axes(C, 2) rngj = nzrange(A, j) for i in 1:(j - 1) - Cd[i, j] += α * _columndot(rv, nz, nzrange(A, i), rngj) + Cd[i, j] = fma(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j]) end - Cd[j, j] += α * sum(i -> 
abs2(nz[i]), rngj) + Cd[j, j] = fma(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j]) end end return C @@ -109,7 +109,7 @@ function rankUpdate!( isone(β) || rmul!(Cdiag, β) @inbounds for i in eachindex(Cdiag) - Cdiag[i] += α * sum(abs2, view(A, i, :)) + Cdiag[i] = fma(α, sum(abs2, view(A, i, :)), Cdiag[i]) end return C @@ -132,9 +132,9 @@ function rankUpdate!( AtAij = 0 for idx in axes(A, 2) # because the second multiplicant is from A', swap index order - AtAij += A[iind, idx] * A[jind, idx] + AtAij = fma(A[iind, idx], A[jind, idx], AtAij) end - Cdat[i, j, k] += α * AtAij + Cdat[i, j, k] = fma(α, AtAij, Cdat[i, j, k]) end end @@ -152,7 +152,7 @@ function rankUpdate!( throw(ArgumentError("Columns of A must have exactly 1 nonzero")) for (r, nz) in zip(rowvals(A), nonzeros(A)) - dd[r] += α * abs2(nz) + dd[r] = fma(α, abs2(nz), dd[r]) end return C diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 74251d6b7..56995f517 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -767,7 +767,7 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T} z = trm.z stride = size(z, 1) mul!( - view(rhs2, (rhsoffset + (trm.refs[i] - 1) * stride) .+ Base.OneTo(stride)), + view(rhs2, fma((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)), adjoint(trm.λ), view(z, :, i), ) @@ -816,7 +816,7 @@ function objective(m::LinearMixedModel{T}) where {T} val = if isnothing(σ) logdet(m) + denomdf * (one(T) + log2π + log(pwrss(m) / denomdf)) else - denomdf * (log2π + 2 * log(σ)) + logdet(m) + pwrss(m) / σ^2 + fma(denomdf, fma(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2)) end return isempty(wts) ? 
val : val - T(2.0) * sum(log, wts) end diff --git a/src/remat.jl b/src/remat.jl index a23430190..662e297c2 100644 --- a/src/remat.jl +++ b/src/remat.jl @@ -284,7 +284,7 @@ function LinearAlgebra.mul!( @inbounds for (j, rrj) in enumerate(B.refs) αzj = α * zz[j] for i in 1:p - C[i, rrj] += αzj * Awt[j, i] + C[i, rrj] = fma(αzj, Awt[j, i], C[i, rrj]) end end return C @@ -310,7 +310,7 @@ function LinearAlgebra.mul!( aki = α * Awt[k, i] kk = Int(rr[k]) for ii in 1:S - scr[ii, kk] += aki * Bwt[ii, k] + scr[ii, kk] = fma(aki, Bwt[ii, k], scr[ii, kk]) end end for j in 1:q @@ -340,7 +340,7 @@ function LinearAlgebra.mul!( coljlast = Int(C.colptr[j + 1] - 1) K = searchsortedfirst(rv, i, Int(C.colptr[j]), coljlast, Base.Order.Forward) if K ≤ coljlast && rv[K] == i - nz[K] += Az[k] * Bz[k] + nz[K] = fma(Az[k], Bz[k], nz[K]) else throw(ArgumentError("C does not have the nonzero pattern of A'B")) end @@ -361,7 +361,7 @@ function LinearAlgebra.mul!( @inbounds for i in 1:S zij = Awtz[i, j] for k in 1:S - Cd[k, i, r] += zij * Awtz[k, j] + Cd[k, i, r] = fma(zij, Awtz[k, j], Cd[k, i, r]) end end end @@ -397,7 +397,7 @@ function LinearAlgebra.mul!( jjo = jj + joffset Bzijj = Bz[jj, i] for ii in 1:S - C[ii + ioffset, jjo] += Az[ii, i] * Bzijj + C[ii + ioffset, jjo] = fma(Az[ii, i], Bzijj, C[ii + ioffset, jjo]) end end end @@ -416,7 +416,8 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) z = A.z @inbounds for (i, r) in enumerate(A.refs) - y[i] += alpha * b[r] * z[i] + # muladd because explicit fma doesn't work with missings + y[i] = muladd(alpha * b[r], z[i], y[i]) end return y end @@ -446,7 +447,8 @@ function LinearAlgebra.mul!( @inbounds for (i, ii) in enumerate(A.refs) offset = (ii - 1) * k for j in 1:k - y[i] += alpha * Z[j, i] * b[offset + j] + # muladd because explicit fma doesn't work with missings + y[i] = muladd(alpha * Z[j, i], b[offset + j], y[i]) end end return y @@ -466,7 +468,8 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) @inbounds for (i, ii) 
in enumerate(refarray(A)) for j in 1:k - y[i] += alpha * Z[j, i] * B[j, ii] + # muladd because explicit fma doesn't work with missings + y[i] = muladd(alpha * Z[j, i], B[j, ii], y[i]) end end return y @@ -564,7 +567,10 @@ function copyscaleinflate! end function copyscaleinflate!(Ljj::Diagonal{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) where {T} Ldiag, Adiag = Ljj.diag, Ajj.diag - broadcast!((x, λsqr) -> x * λsqr + one(T), Ldiag, Adiag, abs2(only(Λj.λ))) + lambsq = abs2(only(Λj.λ.data)) + @inbounds for i in eachindex(Ldiag, Adiag) + Ldiag[i] = fma(lambsq, Adiag[i], one(T)) + end return Ljj end @@ -572,7 +578,7 @@ function copyscaleinflate!(Ljj::Matrix{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) wh fill!(Ljj, zero(T)) lambsq = abs2(only(Λj.λ.data)) @inbounds for (i, a) in enumerate(Ajj.diag) - Ljj[i, i] = lambsq * a + one(T) + Ljj[i, i] = fma(lambsq, a, one(T)) end return Ljj end @@ -606,14 +612,14 @@ function copyscaleinflate!( iszero(r) || throw(DimensionMismatch("size(Ljj, 1) is not a multiple of S")) λ = Λj.λ offset = 0 - @inbounds for k in 1:q + @inbounds for _ in 1:q inds = (offset + 1):(offset + S) tmp = view(Ljj, inds, inds) lmul!(adjoint(λ), rmul!(tmp, λ)) offset += S end for k in diagind(Ljj) - Ljj[k] += 1 + Ljj[k] += one(T) end return Ljj end diff --git a/test/pls.jl b/test/pls.jl index 693609a41..6dd8e9583 100644 --- a/test/pls.jl +++ b/test/pls.jl @@ -139,7 +139,7 @@ end vc = fm1.vcov @test isa(vc, Matrix{Float64}) - @test only(vc) ≈ 375.7167775 rtol=1.e-6 + @test only(vc) ≈ 375.7167775 rtol=1.e-3 # since we're caching the fits, we should get it back to being correctly fitted # we also take this opportunity to test fitlog @testset "fitlog" begin From edc54d2cf0af7825be6b018e8eae82789420d706 Mon Sep 17 00:00:00 2001 From: Phillip Alday Date: Tue, 5 Mar 2024 11:18:36 -0600 Subject: [PATCH 2/4] use muladd everywhere --- src/linalg.jl | 2 +- src/linalg/rankUpdate.jl | 18 +++++++++--------- src/linearmixedmodel.jl | 4 ++-- src/remat.jl | 20 ++++++++++---------- 4 files 
changed, 22 insertions(+), 22 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 59a40478d..630ef0299 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -18,7 +18,7 @@ function LinearAlgebra.mul!( αbnz = α * bnz[ib] jj = brv[ib] for ia in nzrange(A, j) - C[arv[ia], jj] = fma(anz[ia], αbnz, C[arv[ia], jj]) + C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj]) end end end diff --git a/src/linalg/rankUpdate.jl b/src/linalg/rankUpdate.jl index a0696746c..1b81dd35e 100644 --- a/src/linalg/rankUpdate.jl +++ b/src/linalg/rankUpdate.jl @@ -22,7 +22,7 @@ function MixedModels.rankUpdate!( Cdiag = C.data.diag Adiag = A.diag @inbounds for idx in eachindex(Cdiag, Adiag) - Cdiag[idx] = fma(β,Cdiag[idx], α * abs2(Adiag[idx])) + Cdiag[idx] = muladd(β,Cdiag[idx], α * abs2(Adiag[idx])) end return C end @@ -52,7 +52,7 @@ function _columndot(rv, nz, rngi, rngj) while i ≤ ni && j ≤ nj @inbounds ri, rj = rv[rngi[i]], rv[rngj[j]] if ri == rj - @inbounds accum = fma(nz[rngi[i]], nz[rngj[j]], accum) + @inbounds accum = muladd(nz[rngi[i]], nz[rngj[j]], accum) i += 1 j += 1 elseif ri < rj @@ -80,7 +80,7 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, rvj = rv[j] for i in k:lenrngjj kk = rangejj[i] - Cd[rv[kk], rvj] = fma(nz[kk], anzj, Cd[rv[kk], rvj]) + Cd[rv[kk], rvj] = muladd(nz[kk], anzj, Cd[rv[kk], rvj]) end end end @@ -88,9 +88,9 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T, @inbounds for j in axes(C, 2) rngj = nzrange(A, j) for i in 1:(j - 1) - Cd[i, j] = fma(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j]) + Cd[i, j] = muladd(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j]) end - Cd[j, j] = fma(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j]) + Cd[j, j] = muladd(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j]) end end return C @@ -109,7 +109,7 @@ function rankUpdate!( isone(β) || rmul!(Cdiag, β) @inbounds for i in eachindex(Cdiag) - Cdiag[i] = fma(α, sum(abs2, view(A, i, :)), Cdiag[i]) + Cdiag[i] = 
muladd(α, sum(abs2, view(A, i, :)), Cdiag[i]) end return C @@ -132,9 +132,9 @@ function rankUpdate!( AtAij = 0 for idx in axes(A, 2) # because the second multiplicant is from A', swap index order - AtAij = fma(A[iind, idx], A[jind, idx], AtAij) + AtAij = muladd(A[iind, idx], A[jind, idx], AtAij) end - Cdat[i, j, k] = fma(α, AtAij, Cdat[i, j, k]) + Cdat[i, j, k] = muladd(α, AtAij, Cdat[i, j, k]) end end @@ -152,7 +152,7 @@ function rankUpdate!( throw(ArgumentError("Columns of A must have exactly 1 nonzero")) for (r, nz) in zip(rowvals(A), nonzeros(A)) - dd[r] = fma(α, abs2(nz), dd[r]) + dd[r] = muladd(α, abs2(nz), dd[r]) end return C diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 56995f517..92cf3b786 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -767,7 +767,7 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T} z = trm.z stride = size(z, 1) mul!( - view(rhs2, fma((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)), + view(rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)), adjoint(trm.λ), view(z, :, i), ) @@ -816,7 +816,7 @@ function objective(m::LinearMixedModel{T}) where {T} val = if isnothing(σ) logdet(m) + denomdf * (one(T) + log2π + log(pwrss(m) / denomdf)) else - fma(denomdf, fma(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2)) + muladd(denomdf, muladd(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2)) end return isempty(wts) ? 
val : val - T(2.0) * sum(log, wts) end diff --git a/src/remat.jl b/src/remat.jl index 662e297c2..74d92ff88 100644 --- a/src/remat.jl +++ b/src/remat.jl @@ -284,7 +284,7 @@ function LinearAlgebra.mul!( @inbounds for (j, rrj) in enumerate(B.refs) αzj = α * zz[j] for i in 1:p - C[i, rrj] = fma(αzj, Awt[j, i], C[i, rrj]) + C[i, rrj] = muladd(αzj, Awt[j, i], C[i, rrj]) end end return C @@ -310,7 +310,7 @@ function LinearAlgebra.mul!( aki = α * Awt[k, i] kk = Int(rr[k]) for ii in 1:S - scr[ii, kk] = fma(aki, Bwt[ii, k], scr[ii, kk]) + scr[ii, kk] = muladd(aki, Bwt[ii, k], scr[ii, kk]) end end for j in 1:q @@ -340,7 +340,7 @@ function LinearAlgebra.mul!( coljlast = Int(C.colptr[j + 1] - 1) K = searchsortedfirst(rv, i, Int(C.colptr[j]), coljlast, Base.Order.Forward) if K ≤ coljlast && rv[K] == i - nz[K] = fma(Az[k], Bz[k], nz[K]) + nz[K] = muladd(Az[k], Bz[k], nz[K]) else throw(ArgumentError("C does not have the nonzero pattern of A'B")) end @@ -361,7 +361,7 @@ function LinearAlgebra.mul!( @inbounds for i in 1:S zij = Awtz[i, j] for k in 1:S - Cd[k, i, r] = fma(zij, Awtz[k, j], Cd[k, i, r]) + Cd[k, i, r] = muladd(zij, Awtz[k, j], Cd[k, i, r]) end end end @@ -397,7 +397,7 @@ function LinearAlgebra.mul!( jjo = jj + joffset Bzijj = Bz[jj, i] for ii in 1:S - C[ii + ioffset, jjo] = fma(Az[ii, i], Bzijj, C[ii + ioffset, jjo]) + C[ii + ioffset, jjo] = muladd(Az[ii, i], Bzijj, C[ii + ioffset, jjo]) end end end @@ -416,7 +416,7 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) z = A.z @inbounds for (i, r) in enumerate(A.refs) - # muladd because explicit fma doesn't work with missings + # must be muladd and not fma because of potential missings y[i] = muladd(alpha * b[r], z[i], y[i]) end return y @@ -447,7 +447,7 @@ function LinearAlgebra.mul!( @inbounds for (i, ii) in enumerate(A.refs) offset = (ii - 1) * k for j in 1:k - # muladd because explicit fma doesn't work with missings + # must be muladd and not fma because of potential missings y[i] = muladd(alpha * Z[j, i], 
b[offset + j], y[i]) end end @@ -468,7 +468,7 @@ function LinearAlgebra.mul!( isone(beta) || rmul!(y, beta) @inbounds for (i, ii) in enumerate(refarray(A)) for j in 1:k - # muladd because explicit fma doesn't work with missings + # must be muladd and not fma because of potential missings y[i] = muladd(alpha * Z[j, i], B[j, ii], y[i]) end end @@ -569,7 +569,7 @@ function copyscaleinflate!(Ljj::Diagonal{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) Ldiag, Adiag = Ljj.diag, Ajj.diag lambsq = abs2(only(Λj.λ.data)) @inbounds for i in eachindex(Ldiag, Adiag) - Ldiag[i] = fma(lambsq, Adiag[i], one(T)) + Ldiag[i] = muladd(lambsq, Adiag[i], one(T)) end return Ljj end @@ -578,7 +578,7 @@ function copyscaleinflate!(Ljj::Matrix{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) wh fill!(Ljj, zero(T)) lambsq = abs2(only(Λj.λ.data)) @inbounds for (i, a) in enumerate(Ajj.diag) - Ljj[i, i] = fma(lambsq, a, one(T)) + Ljj[i, i] = muladd(lambsq, a, one(T)) end return Ljj end From a9873f67ffc269f4bea3d4b4e188177fba30f952 Mon Sep 17 00:00:00 2001 From: Phillip Alday Date: Tue, 5 Mar 2024 11:21:14 -0600 Subject: [PATCH 3/4] NEWS update --- NEWS.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NEWS.md b/NEWS.md index f0c4aefd7..f9a543308 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,5 @@ +* Use `muladd` where possible to enable fused multiply-add (FMA) on architectures with hardware support. FMA generally improves computational speed and gives more accurate rounding. [#740] + MixedModels v4.22.4 Release Notes ============================== * Switch to explicit imports from all included packages (i.e. 
replace `using Foo` by `using Foo: Foo, bar, baz`) [#748] @@ -495,5 +497,6 @@ Package dependencies [#715]: https://github.com/JuliaStats/MixedModels.jl/issues/715 [#717]: https://github.com/JuliaStats/MixedModels.jl/issues/717 [#733]: https://github.com/JuliaStats/MixedModels.jl/issues/733 +[#740]: https://github.com/JuliaStats/MixedModels.jl/issues/740 [#744]: https://github.com/JuliaStats/MixedModels.jl/issues/744 [#748]: https://github.com/JuliaStats/MixedModels.jl/issues/748 From d46079df0de741c0c6a7b86166794e73849a7807 Mon Sep 17 00:00:00 2001 From: Phillip Alday Date: Tue, 5 Mar 2024 11:29:31 -0600 Subject: [PATCH 4/4] format --- src/linalg.jl | 2 +- src/linalg/rankUpdate.jl | 2 +- src/linearmixedmodel.jl | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 630ef0299..91e628153 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -18,7 +18,7 @@ function LinearAlgebra.mul!( αbnz = α * bnz[ib] jj = brv[ib] for ia in nzrange(A, j) - C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj]) + C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj]) end end end diff --git a/src/linalg/rankUpdate.jl b/src/linalg/rankUpdate.jl index 1b81dd35e..e6f1aacd6 100644 --- a/src/linalg/rankUpdate.jl +++ b/src/linalg/rankUpdate.jl @@ -22,7 +22,7 @@ function MixedModels.rankUpdate!( Cdiag = C.data.diag Adiag = A.diag @inbounds for idx in eachindex(Cdiag, Adiag) - Cdiag[idx] = muladd(β,Cdiag[idx], α * abs2(Adiag[idx])) + Cdiag[idx] = muladd(β, Cdiag[idx], α * abs2(Adiag[idx])) end return C end diff --git a/src/linearmixedmodel.jl b/src/linearmixedmodel.jl index 92cf3b786..3827f41ec 100644 --- a/src/linearmixedmodel.jl +++ b/src/linearmixedmodel.jl @@ -767,7 +767,9 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T} z = trm.z stride = size(z, 1) mul!( - view(rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)), + view( + rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride) + ), 
adjoint(trm.λ), view(z, :, i), )