Skip to content

Commit

Permalink
Merge pull request #160 from baggepinnen/large
Browse files Browse the repository at this point in the history
Performance optimizations for large systems
  • Loading branch information
baggepinnen authored Nov 21, 2024
2 parents ed28c3b + c2d7033 commit 1904666
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 129 deletions.
74 changes: 53 additions & 21 deletions src/ekf.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type AbstractExtendedKalmanFilter <: AbstractKalmanFilter end
@with_kw struct ExtendedKalmanFilter{KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter
abstract type AbstractExtendedKalmanFilter{IPD,IPM} <: AbstractKalmanFilter end
@with_kw struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter{IPD,IPM}
kf::KF
dynamics::F
measurement::G
Expand All @@ -8,7 +8,7 @@ abstract type AbstractExtendedKalmanFilter <: AbstractKalmanFilter end
end

"""
ExtendedKalmanFilter(kf, dynamics, measurement)
ExtendedKalmanFilter(kf, dynamics, measurement; Ajac, Cjac)
ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=MvNormal(Matrix(R1)); nu::Int, p = NullParameters(), α = 1.0, check = true)
A nonlinear state estimator propagating uncertainty using linearization.
Expand Down Expand Up @@ -39,29 +39,49 @@ function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Mat
x = zeros(T, nx)
u = zeros(T, nu)
end
t = one(T)
if Ajac === nothing
Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x)
end
if Cjac === nothing
Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x)
end
A = Ajac(x,u,p,t)
t = zero(T)
A = zeros(nx, nx) # This one is never needed
B = zeros(nx, nu) # This one is never needed
C = Cjac(x,u,p,t)
C = zeros(ny, nx) # This one is never needed
D = zeros(ny, nu) # This one is never needed
kf = KalmanFilter(A,B,C,D,R1,R2,d0; Ts, p, α, check)
return ExtendedKalmanFilter(kf, dynamics, measurement, Ajac, Cjac)

return ExtendedKalmanFilter(kf, dynamics, measurement; Ajac, Cjac)
end

function ExtendedKalmanFilter(kf, dynamics, measurement; Ajac = nothing, Cjac = nothing)
IPD = has_ip(dynamics)
IPM = has_ip(measurement)
if Ajac === nothing
Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x)
# if IPD
# inner! = (xd,x)->dynamics(xd,x,u,p,t)
# out = zeros(eltype(kf.d0), length(kf.x))
# cfg = ForwardDiff.JacobianConfig(inner!, out, x)
# Ajac = (x,u,p,t) -> ForwardDiff.jacobian!((xd,x)->dynamics(xd,x,u,p,t), out, x, cfg, Val(false))
# else
# inner = x->dynamics(x,u,p,t)
# cfg = ForwardDiff.JacobianConfig(inner, kf.x)
# Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x, cfg, Val(false))
# end

if IPD
outx = zeros(eltype(kf.d0), kf.nx)
jacx = zeros(eltype(kf.d0), kf.nx, kf.nx)
Ajac = (x,u,p,t) -> ForwardDiff.jacobian!(jacx, (xd,x)->dynamics(xd,x,u,p,t), outx, x)
else
Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x)
end
end
if Cjac === nothing
Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x)
if IPM
outy = zeros(eltype(kf.d0), kf.ny)
jacy = zeros(eltype(kf.d0), kf.ny, kf.nx)
Cjac = (x,u,p,t) -> ForwardDiff.jacobian!(jacy, (y,x)->measurement(y,x,u,p,t), outy, x)
else
Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x)
end
end
return ExtendedKalmanFilter(kf, dynamics, measurement, Ajac, Cjac)
return ExtendedKalmanFilter{IPD,IPM,typeof(kf),typeof(dynamics),typeof(measurement),typeof(Ajac),typeof(Cjac)}(kf, dynamics, measurement, Ajac, Cjac)
end

function Base.getproperty(ekf::EKF, s::Symbol) where EKF <: AbstractExtendedKalmanFilter
Expand All @@ -79,10 +99,16 @@ function Base.propertynames(ekf::EKF, private::Bool=false) where EKF <: Abstract
end


function predict!(kf::AbstractExtendedKalmanFilter, u, p = parameters(kf), t::Real = index(kf)*kf.Ts; R1 = get_mat(kf.R1, kf.x, u, p, t), α = kf.α)
function predict!(kf::AbstractExtendedKalmanFilter{IPD}, u, p = parameters(kf), t::Real = index(kf)*kf.Ts; R1 = get_mat(kf.R1, kf.x, u, p, t), α = kf.α) where IPD
@unpack x,R = kf
A = kf.Ajac(x, u, p, t)
kf.x = kf.dynamics(x, u, p, t)
if IPD
xp = similar(x)
kf.dynamics(xp, x, u, p, t)
kf.x = xp
else
kf.x = kf.dynamics(x, u, p, t)
end
if α == 1
kf.R = symmetrize(A*R*A') + R1
else
Expand All @@ -91,10 +117,16 @@ function predict!(kf::AbstractExtendedKalmanFilter, u, p = parameters(kf), t::Re
kf.t += 1
end

function correct!(kf::AbstractExtendedKalmanFilter, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t))
function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t)) where IPM
@unpack x,R = kf
C = kf.Cjac(x, u, p, t)
e = y .- kf.measurement(x, u, p, t)
C = kf.Cjac(x, u, p, t)
if IPM
e = zeros(length(y))
kf.measurement(e, x, u, p, t)
e .= y .- e
else
e = y .- kf.measurement(x, u, p, t)
end
S = symmetrize(C*R*C') + R2
Sᵪ = cholesky(S)
K = (R*C')/Sᵪ
Expand Down
19 changes: 13 additions & 6 deletions src/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,32 @@ function predict!(kf::AbstractKalmanFilter, u, p=parameters(kf), t::Real = index
Bt = get_mat(B, x, u, p, t)
kf.x = At*x .+ Bt*u |> vec
if α == 1
Ru = symmetrize(At*R*At')
kf.R = Ru + R1
if R isa SMatrix
kf.R = symmetrize(At*R*At') + R1
else
AtR = At*R
mul!(R, AtR, At')
symmetrize(R)
R .+= R1
end
else
Ru = symmetrize*At*R*At')
kf.R = Ru + R1
@bangbang kf.R .= Ru .+ R1
end
kf.t += 1
end

@inline function symmetrize(x::SArray)
x = 0.5 .* (x .+ x')
Symmetric(x)
x
end
@inline function symmetrize(x)
n = size(x,1)
@inbounds for i = 1:n, j = i+1:n
x[i,j] = 0.5 * (x[i,j] + x[j,i])
x[j,i] = x[i,j]
end
Symmetric(x)
x
end

"""
Expand All @@ -85,7 +91,8 @@ function correct!(kf::AbstractKalmanFilter, u, y, p=parameters(kf), t::Real = in
if !iszero(D)
e -= Dt*u
end
S = symmetrize(Ct*R*Ct') + R2
S = symmetrize(Ct*R*Ct')
@bangbang S .+= R2
Sᵪ = cholesky(S)
K = (R*Ct')/Sᵪ
kf.x += K*e
Expand Down
27 changes: 27 additions & 0 deletions src/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
abstract type AbstractKalmanFilter <: AbstractFilter end

function convert_cov_type(R1, R)
if !(eltype(R) <: AbstractFloat)
R = float.(R)
end
if R isa SMatrix || R isa Matrix
return copy(R)
elseif R1 isa SMatrix && size(R) == size(R1)
Expand Down Expand Up @@ -68,6 +71,9 @@ function KalmanFilter(A,B,C,D,R1,R2,d0=SimpleMvNormal(Matrix(R1)); Ts = 1, p = N
α 1 || @warn "α should be > 1 for exponential forgetting. An α < 1 will lead to exponential loss of adaptation over time."
maximum(abs, eigvals(A isa SMatrix ? Matrix(A) : A)) 2 && @warn "The dynamics matrix A has eigenvalues with absolute value ≥ 2. This is either a highly unstable system, or you have forgotten to discretize a continuous-time model. If you are sure that the system is provided in discrete time, you can disable this warning by setting check=false." maxlog=1
end
if D == 0
D = zeros(eltype(A), size(C,1), size(B,2))
end
R = convert_cov_type(R1, d0.Σ)
x0 = convert_x0_type(d0.μ)
KalmanFilter(A,B,C,D,R1,R2, d0, x0, R, 0, Ts, p, α)
Expand Down Expand Up @@ -108,6 +114,27 @@ function measurement(kf::AbstractKalmanFilter)
end
end

# This helper struct is used to return a oop measurement function regardless of how the measurement function is defined
struct MeasurementOop
kf::AbstractKalmanFilter
end

function (kfm::MeasurementOop)(x,u,p,t)
kf = kfm.kf
mfun = measurement(kf)
if has_ip(mfun)
y = zeros(kf.ny)
mfun(y,x,u,p,t)
return y
else
return mfun(x,u,p,t)
end
end

function measurement_oop(kf::AbstractKalmanFilter)
MeasurementOop(kf)
end

function dynamics(kf::AbstractKalmanFilter)
(x,u,p,t) -> get_mat(kf.A, x, u, p, t)*x + get_mat(kf.B, x, u, p, t)*u
end
Expand Down
19 changes: 17 additions & 2 deletions src/solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ struct KalmanFilteringSolution{F,Tu,Ty,Tx,Txt,TR,TRt,Tll} <: AbstractFilteringSo
ll::Tll
end

@recipe function plot(timevec::AbstractVector{<:Real}, sol::KalmanFilteringSolution; plotx = true, plotxt=true, plotu=true, ploty=true, name = "")
@recipe function plot(timevec::AbstractVector{<:Real}, sol::KalmanFilteringSolution; plotx = true, plotxt=true, plotu=true, ploty=true, plotyh=false, plotyht=true, name = "")
isempty(name) || (name = name*" ")
kf = sol.f
nx, nu, ny = length(sol.x[1]), length(sol.u[1]), length(sol.y[1])
layout --> nx*(plotx || plotxt) + plotu*nu + ploty*ny
layout --> nx*(plotx || plotxt) + plotu*nu + (ploty || plotyh || plotyht)*ny
plotx && @series begin
label --> ["$(name)x$(i)(t|t-1)" for i in 1:nx] |> permutedims
subplot --> (1:nx)'
Expand All @@ -45,6 +46,20 @@ end
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
timevec, reduce(hcat, sol.y)'
end
plotyh && @series begin
label --> ["$(i)(t|t-1)" for i in 1:ny] |> permutedims
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
linestyle --> :dash
yh = measurement_oop(kf).(sol.x, sol.u, Ref(kf.p), timevec)
timevec, reduce(hcat, yh)'
end
plotyht && @series begin
label --> ["$(i)(t|t)" for i in 1:ny] |> permutedims
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
linestyle --> :dash
yht = measurement_oop(kf).(sol.xt, sol.u, Ref(kf.p), timevec)
timevec, reduce(hcat, yht)'
end
end

@recipe function plot(sol::KalmanFilteringSolution)
Expand Down
24 changes: 20 additions & 4 deletions src/sq_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ function correct!(kf::SqKalmanFilter, u, y, p=parameters(kf), t::Real = index(kf
end
S0 = qr([R*Ct';R2]).R
S = UpperTriangular(S0)
if any(<(0), @view(S0[diagind(S0)])) || det(S) < 0 # Cheap for triangular matrices
S0 = -S0 # To avoid log(negative) in logpdf
end
S0 = signdet!(S0, S)
K = ((R'*(R*Ct'))/S)/(S')
kf.x += K*e
M = [R*(I - K*Ct)';R2*K']
Expand All @@ -148,4 +146,22 @@ function correct!(kf::SqKalmanFilter, u, y, p=parameters(kf), t::Real = index(kf
Sᵪ = Cholesky(S0, 'U', 0)
ll = extended_logpdf(SimpleMvNormal(PDMat(SS, Sᵪ)), e)# - 1/2*logdet(S) # logdet is included in logpdf
(; ll, e, SS, Sᵪ, K)
end
end

@inline function signdet!(S0, S)
@inbounds for rc in axes(S0, 1)
# In order to get a well-defined logdet, we need to enforce a positive diagonal of the R factor
if S0[rc,rc] < 0
for c = rc:size(S0, 2)
S0[rc, c] = -S0[rc,c]
end
end
end
S0
end

@inline function signdet!(S0::SMatrix, S)
Stemp = similar(S0) .= S0
signdet!(Stemp, S)
SMatrix(Stemp)
end
Loading

0 comments on commit 1904666

Please sign in to comment.